Ver Fonte

refac/enh: async process chat handling

Timothy Jaeryang Baek há 1 mês atrás
pai
commit
d6f709574e
2 ficheiros alterados com 63 adições e 55 exclusões
  1. backend/open_webui/main.py (+62 -47)
  2. backend/open_webui/utils/middleware.py (+1 -8)

+ 62 - 47
backend/open_webui/main.py

@@ -57,6 +57,7 @@ from open_webui.utils.logger import start_logger
 from open_webui.socket.main import (
     app as socket_app,
     periodic_usage_pool_cleanup,
+    get_event_emitter,
     get_models_in_use,
     get_active_user_ids,
 )
@@ -466,6 +467,7 @@ from open_webui.utils.redis import get_redis_connection
 from open_webui.tasks import (
     redis_task_command_listener,
     list_task_ids_by_item_id,
+    create_task,
     stop_task,
     list_tasks,
 )  # Import from tasks.py
@@ -1473,65 +1475,78 @@ async def chat_completion(
         request.state.metadata = metadata
         form_data["metadata"] = metadata
 
-        form_data, metadata, events = await process_chat_payload(
-            request, form_data, user, metadata, model
-        )
     except Exception as e:
-        log.debug(f"Error processing chat payload: {e}")
-        if metadata.get("chat_id") and metadata.get("message_id"):
-            # Update the chat message with the error
-            try:
-                Chats.upsert_message_to_chat_by_id_and_message_id(
-                    metadata["chat_id"],
-                    metadata["message_id"],
-                    {
-                        "error": {"content": str(e)},
-                    },
-                )
-            except:
-                pass
-
+        log.debug(f"Error processing chat metadata: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=str(e),
         )
 
-    try:
-        response = await chat_completion_handler(request, form_data, user)
-        if metadata.get("chat_id") and metadata.get("message_id"):
-            try:
-                Chats.upsert_message_to_chat_by_id_and_message_id(
-                    metadata["chat_id"],
-                    metadata["message_id"],
-                    {
-                        "model": model_id,
-                    },
-                )
-            except:
-                pass
+    async def process_chat(request, form_data, user, metadata, model):
+        try:
+            form_data, metadata, events = await process_chat_payload(
+                request, form_data, user, metadata, model
+            )
 
-        return await process_chat_response(
-            request, response, form_data, user, metadata, model, events, tasks
-        )
-    except Exception as e:
-        log.debug(f"Error in chat completion: {e}")
-        if metadata.get("chat_id") and metadata.get("message_id"):
-            # Update the chat message with the error
+            response = await chat_completion_handler(request, form_data, user)
+            if metadata.get("chat_id") and metadata.get("message_id"):
+                try:
+                    Chats.upsert_message_to_chat_by_id_and_message_id(
+                        metadata["chat_id"],
+                        metadata["message_id"],
+                        {
+                            "model": model_id,
+                        },
+                    )
+                except:
+                    pass
+
+            return await process_chat_response(
+                request, response, form_data, user, metadata, model, events, tasks
+            )
+        except asyncio.CancelledError:
+            log.info("Chat processing was cancelled")
             try:
-                Chats.upsert_message_to_chat_by_id_and_message_id(
-                    metadata["chat_id"],
-                    metadata["message_id"],
-                    {
-                        "error": {"content": str(e)},
-                    },
+                event_emitter = get_event_emitter(metadata)
+                await event_emitter(
+                    {"type": "task-cancelled"},
                 )
-            except:
+            except Exception as e:
                 pass
+        except Exception as e:
+            log.debug(f"Error processing chat payload: {e}")
+            if metadata.get("chat_id") and metadata.get("message_id"):
+                # Update the chat message with the error
+                try:
+                    Chats.upsert_message_to_chat_by_id_and_message_id(
+                        metadata["chat_id"],
+                        metadata["message_id"],
+                        {
+                            "error": {"content": str(e)},
+                        },
+                    )
+                except:
+                    pass
 
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=str(e),
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=str(e),
+            )
+
+    if (
+        metadata.get("session_id")
+        and metadata.get("chat_id")
+        and metadata.get("message_id")
+    ):
+        # Asynchronous Chat Processing
+        task_id, _ = await create_task(
+            request.app.state.redis,
+            process_chat(request, form_data, user, metadata, model),
+            id=metadata["chat_id"],
         )
+        return {"status": True, "task_id": task_id}
+    else:
+        return await process_chat(request, form_data, user, metadata, model)
 
 
 # Alias for chat_completion (Legacy)

+ 1 - 8
backend/open_webui/utils/middleware.py

@@ -86,7 +86,6 @@ from open_webui.utils.filter import (
 from open_webui.utils.code_interpreter import execute_code_jupyter
 from open_webui.utils.payload import apply_model_system_prompt_to_body
 
-from open_webui.tasks import create_task
 
 from open_webui.config import (
     CACHE_DIR,
@@ -2600,13 +2599,7 @@ async def process_chat_response(
             if response.background is not None:
                 await response.background()
 
-        # background_tasks.add_task(response_handler, response, events)
-        task_id, _ = await create_task(
-            request.app.state.redis,
-            response_handler(response, events),
-            id=metadata["chat_id"],
-        )
-        return {"status": True, "task_id": task_id}
+        return await response_handler(response, events)
 
     else:
         # Fallback to the original response