Parcourir la source

feat: ref chat

Timothy Jaeryang Baek il y a 3 semaines
Parent
commit
aa8ab349ed

+ 1 - 1
backend/open_webui/models/chats.py

@@ -236,7 +236,7 @@ class ChatTable:
 
         return chat.chat.get("title", "New Chat")
 
-    def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
+    def get_messages_map_by_chat_id(self, id: str) -> Optional[dict]:
         chat = self.get_chat_by_id(id)
         if chat is None:
             return None

+ 30 - 0
backend/open_webui/retrieval/utils.py

@@ -19,10 +19,13 @@ from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
 from open_webui.models.users import UserModel
 from open_webui.models.files import Files
 from open_webui.models.knowledge import Knowledges
+
+from open_webui.models.chats import Chats
 from open_webui.models.notes import Notes
 
 from open_webui.retrieval.vector.main import GetResult
 from open_webui.utils.access_control import has_access
+from open_webui.utils.misc import get_message_list
 
 
 from open_webui.env import (
@@ -538,6 +541,33 @@ def get_sources_from_items(
                     "metadatas": [[{"file_id": note.id, "name": note.title}]],
                 }
 
+        elif item.get("type") == "chat":
+            # Chat Attached
+            chat = Chats.get_chat_by_id(item.get("id"))
+            print("chat", chat)
+
+            if chat and (user.role == "admin" or chat.user_id == user.id):
+                messages_map = chat.chat.get("history", {}).get("messages", {})
+                message_id = chat.chat.get("history", {}).get("currentId")
+
+                print(messages_map, message_id)
+
+                if messages_map and message_id:
+                    # Reconstruct the message list in order
+                    message_list = get_message_list(messages_map, message_id)
+                    message_history = "\n".join(
+                        [
+                            f"{m.get('role', 'user').capitalize()}: {m.get('content')}"
+                            for m in message_list
+                        ]
+                    )
+
+                    # User has access to the chat
+                    query_result = {
+                        "documents": [[message_history]],
+                        "metadatas": [[{"file_id": chat.id, "name": chat.title}]],
+                    }
+
         elif item.get("type") == "file":
             if (
                 item.get("context") == "full"

+ 3 - 3
backend/open_webui/utils/middleware.py

@@ -1131,11 +1131,11 @@ async def process_chat_response(
     request, response, form_data, user, metadata, model, events, tasks
 ):
     async def background_tasks_handler():
-        message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
-        message = message_map.get(metadata["message_id"]) if message_map else None
+        messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"])
+        message = messages_map.get(metadata["message_id"]) if messages_map else None
 
         if message:
-            message_list = get_message_list(message_map, metadata["message_id"])
+            message_list = get_message_list(messages_map, metadata["message_id"])
 
             # Remove details tags and files from the messages.
             # as get_message_list creates a new list, it does not affect

+ 4 - 4
backend/open_webui/utils/misc.py

@@ -26,7 +26,7 @@ def deep_update(d, u):
     return d
 
 
-def get_message_list(messages, message_id):
+def get_message_list(messages_map, message_id):
     """
     Reconstructs a list of messages in order up to the specified message_id.
 
@@ -36,11 +36,11 @@ def get_message_list(messages, message_id):
     """
 
     # Handle case where messages is None
-    if not messages:
+    if not messages_map:
         return []  # Return empty list instead of None to prevent iteration errors
 
     # Find the message by its id
-    current_message = messages.get(message_id)
+    current_message = messages_map.get(message_id)
 
     if not current_message:
         return []  # Return empty list instead of None to prevent iteration errors
@@ -53,7 +53,7 @@ def get_message_list(messages, message_id):
             0, current_message
         )  # Insert the message at the beginning of the list
         parent_id = current_message.get("parentId")  # Use .get() for safety
-        current_message = messages.get(parent_id) if parent_id else None
+        current_message = messages_map.get(parent_id) if parent_id else None
 
     return message_list
 

+ 1 - 1
src/lib/components/chat/Chat.svelte

@@ -1715,7 +1715,7 @@
 		let files = JSON.parse(JSON.stringify(chatFiles));
 		files.push(
 			...(userMessage?.files ?? []).filter((item) =>
-				['doc', 'text', 'file', 'note', 'collection'].includes(item.type)
+				['doc', 'text', 'file', 'note', 'chat', 'collection'].includes(item.type)
 			)
 		);
 		// Remove duplicates