Ver código fonte

refac/fix: temp chat

Timothy Jaeryang Baek 5 dias atrás
pai
commit
3a601e0fc3

+ 17 - 15
backend/open_webui/main.py

@@ -1495,7 +1495,7 @@ async def chat_completion(
         }
 
         if metadata.get("chat_id") and (user and user.role != "admin"):
-            if metadata["chat_id"] != "local":
+            if not metadata["chat_id"].startswith("local:"):
                 chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id)
                 if chat is None:
                     raise HTTPException(
@@ -1522,13 +1522,14 @@ async def chat_completion(
             response = await chat_completion_handler(request, form_data, user)
             if metadata.get("chat_id") and metadata.get("message_id"):
                 try:
-                    Chats.upsert_message_to_chat_by_id_and_message_id(
-                        metadata["chat_id"],
-                        metadata["message_id"],
-                        {
-                            "model": model_id,
-                        },
-                    )
+                    if not metadata["chat_id"].startswith("local:"):
+                        Chats.upsert_message_to_chat_by_id_and_message_id(
+                            metadata["chat_id"],
+                            metadata["message_id"],
+                            {
+                                "model": model_id,
+                            },
+                        )
                 except:
                     pass
 
@@ -1549,13 +1550,14 @@ async def chat_completion(
             if metadata.get("chat_id") and metadata.get("message_id"):
                 # Update the chat message with the error
                 try:
-                    Chats.upsert_message_to_chat_by_id_and_message_id(
-                        metadata["chat_id"],
-                        metadata["message_id"],
-                        {
-                            "error": {"content": str(e)},
-                        },
-                    )
+                    if not metadata["chat_id"].startswith("local:"):
+                        Chats.upsert_message_to_chat_by_id_and_message_id(
+                            metadata["chat_id"],
+                            metadata["message_id"],
+                            {
+                                "error": {"content": str(e)},
+                            },
+                        )
 
                     event_emitter = get_event_emitter(metadata)
                     await event_emitter(

+ 10 - 4
backend/open_webui/socket/main.py

@@ -653,12 +653,15 @@ def get_event_emitter(request_info, update_db=True):
             )
         )
 
+        chat_id = request_info.get("chat_id", None)
+        message_id = request_info.get("message_id", None)
+
         emit_tasks = [
             sio.emit(
                 "chat-events",
                 {
-                    "chat_id": request_info.get("chat_id", None),
-                    "message_id": request_info.get("message_id", None),
+                    "chat_id": chat_id,
+                    "message_id": message_id,
                     "data": event_data,
                 },
                 to=session_id,
@@ -667,8 +670,11 @@ def get_event_emitter(request_info, update_db=True):
         ]
 
         await asyncio.gather(*emit_tasks)
-
-        if update_db:
+        if (
+            update_db
+            and message_id
+            and not request_info.get("chat_id", "").startswith("local:")
+        ):
             if "type" in event_data and event_data["type"] == "status":
                 Chats.add_message_status_to_chat_by_id_and_message_id(
                     request_info["chat_id"],

+ 114 - 91
backend/open_webui/utils/middleware.py

@@ -80,6 +80,7 @@ from open_webui.utils.misc import (
     add_or_update_system_message,
     add_or_update_user_message,
     get_last_user_message,
+    get_last_user_message_item,
     get_last_assistant_message,
     get_system_message,
     prepend_to_first_user_message_content,
@@ -1418,10 +1419,13 @@ async def process_chat_response(
     request, response, form_data, user, metadata, model, events, tasks
 ):
     async def background_tasks_handler():
-        messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"])
-        message = messages_map.get(metadata["message_id"]) if messages_map else None
+        message = None
+        messages = []
+
+        if "chat_id" in metadata and not metadata["chat_id"].startswith("local:"):
+            messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"])
+            message = messages_map.get(metadata["message_id"]) if messages_map else None
 
-        if message:
             message_list = get_message_list(messages_map, metadata["message_id"])
 
             # Remove details tags and files from the messages.
@@ -1454,12 +1458,21 @@ async def process_chat_response(
                         "content": content,
                     }
                 )
+        else:
+            # Local temp chat, get the model and message from the form_data
+            message = get_last_user_message_item(form_data.get("messages", []))
+            messages = form_data.get("messages", [])
+            if message:
+                message["model"] = form_data.get("model")
 
+        if message and "model" in message:
             if tasks and messages:
                 if (
                     TASKS.FOLLOW_UP_GENERATION in tasks
                     and tasks[TASKS.FOLLOW_UP_GENERATION]
                 ):
+
+                    print("Generating follow ups")
                     res = await generate_follow_ups(
                         request,
                         {
@@ -1490,15 +1503,6 @@ async def process_chat_response(
                             follow_ups = json.loads(follow_ups_string).get(
                                 "follow_ups", []
                             )
-
-                            Chats.upsert_message_to_chat_by_id_and_message_id(
-                                metadata["chat_id"],
-                                metadata["message_id"],
-                                {
-                                    "followUps": follow_ups,
-                                },
-                            )
-
                             await event_emitter(
                                 {
                                     "type": "chat:message:follow_ups",
@@ -1507,111 +1511,130 @@ async def process_chat_response(
                                     },
                                 }
                             )
+
+                            if not metadata.get("chat_id", "").startswith("local:"):
+                                Chats.upsert_message_to_chat_by_id_and_message_id(
+                                    metadata["chat_id"],
+                                    metadata["message_id"],
+                                    {
+                                        "followUps": follow_ups,
+                                    },
+                                )
+
                         except Exception as e:
                             pass
 
-                if TASKS.TITLE_GENERATION in tasks:
-                    user_message = get_last_user_message(messages)
-                    if user_message and len(user_message) > 100:
-                        user_message = user_message[:100] + "..."
+                if not metadata.get("chat_id", "").startswith(
+                    "local:"
+                ):  # Only update titles and tags for non-temp chats
+                    if (
+                        TASKS.TITLE_GENERATION in tasks
+                        and tasks[TASKS.TITLE_GENERATION]
+                    ):
+                        user_message = get_last_user_message(messages)
+                        if user_message and len(user_message) > 100:
+                            user_message = user_message[:100] + "..."
 
-                    if tasks[TASKS.TITLE_GENERATION]:
+                        if tasks[TASKS.TITLE_GENERATION]:
 
-                        res = await generate_title(
-                            request,
-                            {
-                                "model": message["model"],
-                                "messages": messages,
-                                "chat_id": metadata["chat_id"],
-                            },
-                            user,
-                        )
+                            res = await generate_title(
+                                request,
+                                {
+                                    "model": message["model"],
+                                    "messages": messages,
+                                    "chat_id": metadata["chat_id"],
+                                },
+                                user,
+                            )
 
-                        if res and isinstance(res, dict):
-                            if len(res.get("choices", [])) == 1:
-                                title_string = (
-                                    res.get("choices", [])[0]
-                                    .get("message", {})
-                                    .get(
-                                        "content", message.get("content", user_message)
+                            if res and isinstance(res, dict):
+                                if len(res.get("choices", [])) == 1:
+                                    title_string = (
+                                        res.get("choices", [])[0]
+                                        .get("message", {})
+                                        .get(
+                                            "content",
+                                            message.get("content", user_message),
+                                        )
                                     )
-                                )
-                            else:
-                                title_string = ""
+                                else:
+                                    title_string = ""
 
-                            title_string = title_string[
-                                title_string.find("{") : title_string.rfind("}") + 1
-                            ]
+                                title_string = title_string[
+                                    title_string.find("{") : title_string.rfind("}") + 1
+                                ]
 
-                            try:
-                                title = json.loads(title_string).get(
-                                    "title", user_message
+                                try:
+                                    title = json.loads(title_string).get(
+                                        "title", user_message
+                                    )
+                                except Exception as e:
+                                    title = ""
+
+                                if not title:
+                                    title = messages[0].get("content", user_message)
+
+                                Chats.update_chat_title_by_id(
+                                    metadata["chat_id"], title
                                 )
-                            except Exception as e:
-                                title = ""
 
-                            if not title:
-                                title = messages[0].get("content", user_message)
+                                await event_emitter(
+                                    {
+                                        "type": "chat:title",
+                                        "data": title,
+                                    }
+                                )
+                        elif len(messages) == 2:
+                            title = messages[0].get("content", user_message)
 
                             Chats.update_chat_title_by_id(metadata["chat_id"], title)
 
                             await event_emitter(
                                 {
                                     "type": "chat:title",
-                                    "data": title,
+                                    "data": message.get("content", user_message),
                                 }
                             )
-                    elif len(messages) == 2:
-                        title = messages[0].get("content", user_message)
 
-                        Chats.update_chat_title_by_id(metadata["chat_id"], title)
-
-                        await event_emitter(
+                    if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
+                        res = await generate_chat_tags(
+                            request,
                             {
-                                "type": "chat:title",
-                                "data": message.get("content", user_message),
-                            }
+                                "model": message["model"],
+                                "messages": messages,
+                                "chat_id": metadata["chat_id"],
+                            },
+                            user,
                         )
 
-                if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
-                    res = await generate_chat_tags(
-                        request,
-                        {
-                            "model": message["model"],
-                            "messages": messages,
-                            "chat_id": metadata["chat_id"],
-                        },
-                        user,
-                    )
-
-                    if res and isinstance(res, dict):
-                        if len(res.get("choices", [])) == 1:
-                            tags_string = (
-                                res.get("choices", [])[0]
-                                .get("message", {})
-                                .get("content", "")
-                            )
-                        else:
-                            tags_string = ""
+                        if res and isinstance(res, dict):
+                            if len(res.get("choices", [])) == 1:
+                                tags_string = (
+                                    res.get("choices", [])[0]
+                                    .get("message", {})
+                                    .get("content", "")
+                                )
+                            else:
+                                tags_string = ""
 
-                        tags_string = tags_string[
-                            tags_string.find("{") : tags_string.rfind("}") + 1
-                        ]
+                            tags_string = tags_string[
+                                tags_string.find("{") : tags_string.rfind("}") + 1
+                            ]
 
-                        try:
-                            tags = json.loads(tags_string).get("tags", [])
-                            Chats.update_chat_tags_by_id(
-                                metadata["chat_id"], tags, user
-                            )
+                            try:
+                                tags = json.loads(tags_string).get("tags", [])
+                                Chats.update_chat_tags_by_id(
+                                    metadata["chat_id"], tags, user
+                                )
 
-                            await event_emitter(
-                                {
-                                    "type": "chat:tags",
-                                    "data": tags,
-                                }
-                            )
-                        except Exception as e:
-                            pass
+                                await event_emitter(
+                                    {
+                                        "type": "chat:tags",
+                                        "data": tags,
+                                    }
+                                )
+                            except Exception as e:
+                                pass
 
     event_emitter = None
     event_caller = None

+ 2 - 2
src/lib/components/chat/Chat.svelte

@@ -2207,8 +2207,8 @@
 
 			selectedFolder.set(null);
 		} else {
-			_chatId = 'local';
-			await chatId.set('local');
+			_chatId = `local:${$socket?.id}`; // Use socket id for temporary chat
+			await chatId.set(_chatId);
 		}
 		await tick();
 

+ 1 - 1
src/lib/components/chat/Navbar.svelte

@@ -248,7 +248,7 @@
 		</div>
 	</div>
 
-	{#if $temporaryChatEnabled && $chatId === 'local'}
+	{#if $temporaryChatEnabled && ($chatId ?? '').startsWith('local:')}
 		<div class=" w-full z-30 text-center">
 			<div class="text-xs text-gray-500">{$i18n.t('Temporary Chat')}</div>
 		</div>

+ 1 - 1
src/lib/components/layout/Navbar/Menu.svelte

@@ -232,7 +232,7 @@
 		if (chat.id) {
 			let chatObj = null;
 
-			if (chat.id === 'local' || $temporaryChatEnabled) {
+			if ((chat?.id ?? '').startsWith('local') || $temporaryChatEnabled) {
 				chatObj = chat;
 			} else {
 				chatObj = await getChatById(localStorage.token, chat.id);