Browse Source

fix: ongoing chat stop issue

Timothy Jaeryang Baek 3 months ago
parent
commit
f3fe82da80

+ 19 - 3
backend/open_webui/main.py

@@ -372,7 +372,11 @@ from open_webui.utils.auth import (
 from open_webui.utils.oauth import OAuthManager
 from open_webui.utils.oauth import OAuthManager
 from open_webui.utils.security_headers import SecurityHeadersMiddleware
 from open_webui.utils.security_headers import SecurityHeadersMiddleware
 
 
-from open_webui.tasks import stop_task, list_tasks  # Import from tasks.py
+from open_webui.tasks import (
+    list_task_ids_by_chat_id,
+    stop_task,
+    list_tasks,
+)  # Import from tasks.py
 
 
 from open_webui.utils.redis import get_sentinels_from_env
 from open_webui.utils.redis import get_sentinels_from_env
 
 
@@ -1196,7 +1200,7 @@ async def chat_action(
 @app.post("/api/tasks/stop/{task_id}")
 @app.post("/api/tasks/stop/{task_id}")
 async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
 async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
     try:
     try:
-        result = await stop_task(task_id)  # Use the function from tasks.py
+        result = await stop_task(task_id)
         return result
         return result
     except ValueError as e:
     except ValueError as e:
         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@@ -1204,7 +1208,19 @@ async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
 
 
 @app.get("/api/tasks")
 @app.get("/api/tasks")
 async def list_tasks_endpoint(user=Depends(get_verified_user)):
 async def list_tasks_endpoint(user=Depends(get_verified_user)):
-    return {"tasks": list_tasks()}  # Use the function from tasks.py
+    return {"tasks": list_tasks()}
+
+
+@app.get("/api/tasks/chat/{chat_id}")
+async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified_user)):
+    chat = Chats.get_chat_by_id(chat_id)
+    if chat is None or chat.user_id != user.id:
+        return {"task_ids": []}
+
+    task_ids = list_task_ids_by_chat_id(chat_id)
+
+    print(f"Task IDs for chat {chat_id}: {task_ids}")
+    return {"task_ids": task_ids}
 
 
 
 
 ##################################
 ##################################

+ 24 - 4
backend/open_webui/tasks.py

@@ -5,16 +5,23 @@ from uuid import uuid4
 
 
 # A dictionary to keep track of active tasks
 # A dictionary to keep track of active tasks
 tasks: Dict[str, asyncio.Task] = {}
 tasks: Dict[str, asyncio.Task] = {}
+chat_tasks = {}
 
 
 
 
-def cleanup_task(task_id: str):
+def cleanup_task(task_id: str, id=None):
     """
     """
     Remove a completed or canceled task from the global `tasks` dictionary.
     Remove a completed or canceled task from the global `tasks` dictionary.
     """
     """
     tasks.pop(task_id, None)  # Remove the task if it exists
     tasks.pop(task_id, None)  # Remove the task if it exists
 
 
+    # If an ID is provided, remove the task from the chat_tasks dictionary
+    if id and task_id in chat_tasks.get(id, []):
+        chat_tasks[id].remove(task_id)
+        if not chat_tasks[id]:  # If no tasks left for this ID, remove the entry
+            chat_tasks.pop(id, None)
 
 
-def create_task(coroutine):
+
+def create_task(coroutine, id=None):
     """
     """
     Create a new asyncio task and add it to the global task dictionary.
     Create a new asyncio task and add it to the global task dictionary.
     """
     """
@@ -22,9 +29,15 @@ def create_task(coroutine):
     task = asyncio.create_task(coroutine)  # Create the task
     task = asyncio.create_task(coroutine)  # Create the task
 
 
     # Add a done callback for cleanup
     # Add a done callback for cleanup
-    task.add_done_callback(lambda t: cleanup_task(task_id))
-
+    task.add_done_callback(lambda t: cleanup_task(task_id, id))
     tasks[task_id] = task
     tasks[task_id] = task
+
+    # If an ID is provided, associate the task with that ID
+    if chat_tasks.get(id):
+        chat_tasks[id].append(task_id)
+    else:
+        chat_tasks[id] = [task_id]
+
     return task_id, task
     return task_id, task
 
 
 
 
@@ -42,6 +55,13 @@ def list_tasks():
     return list(tasks.keys())
     return list(tasks.keys())
 
 
 
 
+def list_task_ids_by_chat_id(id):
+    """
+    List all tasks associated with a specific ID.
+    """
+    return chat_tasks.get(id, [])
+
+
 async def stop_task(task_id: str):
 async def stop_task(task_id: str):
     """
     """
     Cancel a running task and remove it from the global task list.
     Cancel a running task and remove it from the global task list.

+ 3 - 1
backend/open_webui/utils/middleware.py

@@ -2245,7 +2245,9 @@ async def process_chat_response(
                 await response.background()
                 await response.background()
 
 
         # background_tasks.add_task(post_response_handler, response, events)
         # background_tasks.add_task(post_response_handler, response, events)
-        task_id, _ = create_task(post_response_handler(response, events))
+        task_id, _ = create_task(
+            post_response_handler(response, events), id=metadata["chat_id"]
+        )
         return {"status": True, "task_id": task_id}
         return {"status": True, "task_id": task_id}
 
 
     else:
     else:

+ 32 - 0
src/lib/apis/index.ts

@@ -260,6 +260,38 @@ export const stopTask = async (token: string, id: string) => {
 	return res;
 	return res;
 };
 };
 
 
+export const getTaskIdsByChatId = async (token: string, chat_id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/tasks/chat/${chat_id}`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = err;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getToolServerData = async (token: string, url: string) => {
 export const getToolServerData = async (token: string, url: string) => {
 	let error = null;
 	let error = null;
 
 

+ 11 - 3
src/lib/components/chat/Chat.svelte

@@ -74,7 +74,8 @@
 		generateQueries,
 		generateQueries,
 		chatAction,
 		chatAction,
 		generateMoACompletion,
 		generateMoACompletion,
-		stopTask
+		stopTask,
+		getTaskIdsByChatId
 	} from '$lib/apis';
 	} from '$lib/apis';
 	import { getTools } from '$lib/apis/tools';
 	import { getTools } from '$lib/apis/tools';
 
 
@@ -825,7 +826,14 @@
 					}
 					}
 				}
 				}
 
 
-				taskIds = chat?.task_ids ?? null;
+				const taskRes = await getTaskIdsByChatId(localStorage.token, $chatId).catch((error) => {
+					return null;
+				});
+
+				if (taskRes) {
+					taskIds = taskRes.task_ids;
+				}
+
 				await tick();
 				await tick();
 
 
 				return true;
 				return true;
@@ -1721,7 +1729,6 @@
 			taskIds = null;
 			taskIds = null;
 
 
 			const responseMessage = history.messages[history.currentId];
 			const responseMessage = history.messages[history.currentId];
-
 			// Set all response messages to done
 			// Set all response messages to done
 			for (const messageId of history.messages[responseMessage.parentId].childrenIds) {
 			for (const messageId of history.messages[responseMessage.parentId].childrenIds) {
 				history.messages[messageId].done = true;
 				history.messages[messageId].done = true;
@@ -2014,6 +2021,7 @@
 						<div class=" pb-[1rem]">
 						<div class=" pb-[1rem]">
 							<MessageInput
 							<MessageInput
 								{history}
 								{history}
+								{taskIds}
 								{selectedModels}
 								{selectedModels}
 								bind:files
 								bind:files
 								bind:prompt
 								bind:prompt

+ 87 - 90
src/lib/components/chat/MessageInput.svelte

@@ -71,6 +71,7 @@
 	$: selectedModelIds = atSelectedModel !== undefined ? [atSelectedModel.id] : selectedModels;
 	$: selectedModelIds = atSelectedModel !== undefined ? [atSelectedModel.id] : selectedModels;
 
 
 	export let history;
 	export let history;
+	export let taskIds = null;
 
 
 	export let prompt = '';
 	export let prompt = '';
 	export let files = [];
 	export let files = [];
@@ -1237,116 +1238,112 @@
 											</Tooltip>
 											</Tooltip>
 										{/if}
 										{/if}
 
 
-										{#if !history.currentId || history.messages[history.currentId]?.done == true}
-											{#if prompt === '' && files.length === 0}
-												<div class=" flex items-center">
-													<Tooltip content={$i18n.t('Call')}>
-														<button
-															class=" bg-black text-white hover:bg-gray-900 dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full p-1.5 self-center"
-															type="button"
-															on:click={async () => {
-																if (selectedModels.length > 1) {
-																	toast.error($i18n.t('Select only one model to call'));
+										{#if taskIds && taskIds.length > 0}
+											<div class=" flex items-center">
+												<Tooltip content={$i18n.t('Stop')}>
+													<button
+														class="bg-white hover:bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-white dark:hover:bg-gray-800 transition rounded-full p-1.5"
+														on:click={() => {
+															stopResponse();
+														}}
+													>
+														<svg
+															xmlns="http://www.w3.org/2000/svg"
+															viewBox="0 0 24 24"
+															fill="currentColor"
+															class="size-5"
+														>
+															<path
+																fill-rule="evenodd"
+																d="M2.25 12c0-5.385 4.365-9.75 9.75-9.75s9.75 4.365 9.75 9.75-4.365 9.75-9.75 9.75S2.25 17.385 2.25 12zm6-2.438c0-.724.588-1.312 1.313-1.312h4.874c.725 0 1.313.588 1.313 1.313v4.874c0 .725-.588 1.313-1.313 1.313H9.564a1.312 1.312 0 01-1.313-1.313V9.564z"
+																clip-rule="evenodd"
+															/>
+														</svg>
+													</button>
+												</Tooltip>
+											</div>
+										{:else if prompt === '' && files.length === 0}
+											<div class=" flex items-center">
+												<Tooltip content={$i18n.t('Call')}>
+													<button
+														class=" bg-black text-white hover:bg-gray-900 dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full p-1.5 self-center"
+														type="button"
+														on:click={async () => {
+															if (selectedModels.length > 1) {
+																toast.error($i18n.t('Select only one model to call'));
 
 
-																	return;
-																}
+																return;
+															}
 
 
-																if ($config.audio.stt.engine === 'web') {
-																	toast.error(
-																		$i18n.t(
-																			'Call feature is not supported when using Web STT engine'
-																		)
-																	);
+															if ($config.audio.stt.engine === 'web') {
+																toast.error(
+																	$i18n.t('Call feature is not supported when using Web STT engine')
+																);
 
 
-																	return;
-																}
-																// check if user has access to getUserMedia
-																try {
-																	let stream = await navigator.mediaDevices.getUserMedia({
-																		audio: true
-																	});
-																	// If the user grants the permission, proceed to show the call overlay
+																return;
+															}
+															// check if user has access to getUserMedia
+															try {
+																let stream = await navigator.mediaDevices.getUserMedia({
+																	audio: true
+																});
+																// If the user grants the permission, proceed to show the call overlay
 
 
-																	if (stream) {
-																		const tracks = stream.getTracks();
-																		tracks.forEach((track) => track.stop());
-																	}
+																if (stream) {
+																	const tracks = stream.getTracks();
+																	tracks.forEach((track) => track.stop());
+																}
 
 
-																	stream = null;
+																stream = null;
 
 
-																	if ($settings.audio?.tts?.engine === 'browser-kokoro') {
-																		// If the user has not initialized the TTS worker, initialize it
-																		if (!$TTSWorker) {
-																			await TTSWorker.set(
-																				new KokoroWorker({
-																					dtype: $settings.audio?.tts?.engineConfig?.dtype ?? 'fp32'
-																				})
-																			);
+																if ($settings.audio?.tts?.engine === 'browser-kokoro') {
+																	// If the user has not initialized the TTS worker, initialize it
+																	if (!$TTSWorker) {
+																		await TTSWorker.set(
+																			new KokoroWorker({
+																				dtype: $settings.audio?.tts?.engineConfig?.dtype ?? 'fp32'
+																			})
+																		);
 
 
-																			await $TTSWorker.init();
-																		}
+																		await $TTSWorker.init();
 																	}
 																	}
-
-																	showCallOverlay.set(true);
-																	showControls.set(true);
-																} catch (err) {
-																	// If the user denies the permission or an error occurs, show an error message
-																	toast.error(
-																		$i18n.t('Permission denied when accessing media devices')
-																	);
 																}
 																}
-															}}
-															aria-label="Call"
-														>
-															<Headphone className="size-5" />
-														</button>
-													</Tooltip>
-												</div>
-											{:else}
-												<div class=" flex items-center">
-													<Tooltip content={$i18n.t('Send message')}>
-														<button
-															id="send-message-button"
-															class="{!(prompt === '' && files.length === 0)
-																? 'bg-black text-white hover:bg-gray-900 dark:bg-white dark:text-black dark:hover:bg-gray-100 '
-																: 'text-white bg-gray-200 dark:text-gray-900 dark:bg-gray-700 disabled'} transition rounded-full p-1.5 self-center"
-															type="submit"
-															disabled={prompt === '' && files.length === 0}
-														>
-															<svg
-																xmlns="http://www.w3.org/2000/svg"
-																viewBox="0 0 16 16"
-																fill="currentColor"
-																class="size-5"
-															>
-																<path
-																	fill-rule="evenodd"
-																	d="M8 14a.75.75 0 0 1-.75-.75V4.56L4.03 7.78a.75.75 0 0 1-1.06-1.06l4.5-4.5a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1-1.06 1.06L8.75 4.56v8.69A.75.75 0 0 1 8 14Z"
-																	clip-rule="evenodd"
-																/>
-															</svg>
-														</button>
-													</Tooltip>
-												</div>
-											{/if}
+
+																showCallOverlay.set(true);
+																showControls.set(true);
+															} catch (err) {
+																// If the user denies the permission or an error occurs, show an error message
+																toast.error(
+																	$i18n.t('Permission denied when accessing media devices')
+																);
+															}
+														}}
+														aria-label="Call"
+													>
+														<Headphone className="size-5" />
+													</button>
+												</Tooltip>
+											</div>
 										{:else}
 										{:else}
 											<div class=" flex items-center">
 											<div class=" flex items-center">
-												<Tooltip content={$i18n.t('Stop')}>
+												<Tooltip content={$i18n.t('Send message')}>
 													<button
 													<button
-														class="bg-white hover:bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-white dark:hover:bg-gray-800 transition rounded-full p-1.5"
-														on:click={() => {
-															stopResponse();
-														}}
+														id="send-message-button"
+														class="{!(prompt === '' && files.length === 0)
+															? 'bg-black text-white hover:bg-gray-900 dark:bg-white dark:text-black dark:hover:bg-gray-100 '
+															: 'text-white bg-gray-200 dark:text-gray-900 dark:bg-gray-700 disabled'} transition rounded-full p-1.5 self-center"
+														type="submit"
+														disabled={prompt === '' && files.length === 0}
 													>
 													>
 														<svg
 														<svg
 															xmlns="http://www.w3.org/2000/svg"
 															xmlns="http://www.w3.org/2000/svg"
-															viewBox="0 0 24 24"
+															viewBox="0 0 16 16"
 															fill="currentColor"
 															fill="currentColor"
 															class="size-5"
 															class="size-5"
 														>
 														>
 															<path
 															<path
 																fill-rule="evenodd"
 																fill-rule="evenodd"
-																d="M2.25 12c0-5.385 4.365-9.75 9.75-9.75s9.75 4.365 9.75 9.75-4.365 9.75-9.75 9.75S2.25 17.385 2.25 12zm6-2.438c0-.724.588-1.312 1.313-1.312h4.874c.725 0 1.313.588 1.313 1.313v4.874c0 .725-.588 1.313-1.313 1.313H9.564a1.312 1.312 0 01-1.313-1.313V9.564z"
+																d="M8 14a.75.75 0 0 1-.75-.75V4.56L4.03 7.78a.75.75 0 0 1-1.06-1.06l4.5-4.5a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1-1.06 1.06L8.75 4.56v8.69A.75.75 0 0 1 8 14Z"
 																clip-rule="evenodd"
 																clip-rule="evenodd"
 															/>
 															/>
 														</svg>
 														</svg>