Kaynağa Gözat

refac: memories

Timothy Jaeryang Baek 4 ay önce
ebeveyn
işleme
a2f12db8d9

+ 38 - 0
backend/open_webui/utils/middleware.py

@@ -41,6 +41,7 @@ from open_webui.routers.pipelines import (
     process_pipeline_inlet_filter,
     process_pipeline_outlet_filter,
 )
+from open_webui.routers.memories import query_memory, QueryMemoryForm
 
 from open_webui.utils.webhook import post_webhook
 
@@ -290,6 +291,38 @@ async def chat_completion_tools_handler(
     return body, {"sources": sources}
 
 
+async def chat_memory_handler(
+    request: Request, form_data: dict, extra_params: dict, user
+):
+    results = await query_memory(
+        request,
+        QueryMemoryForm(
+            **{"content": get_last_user_message(form_data["messages"]), "k": 3}
+        ),
+        user,
+    )
+
+    user_context = ""
+    if results and hasattr(results, "documents"):
+        if results.documents and len(results.documents) > 0:
+            for doc_idx, doc in enumerate(results.documents[0]):
+                created_at_date = "Unknown Date"
+
+                if results.metadatas[0][doc_idx].get("created_at"):
+                    created_at_timestamp = results.metadatas[0][doc_idx]["created_at"]
+                    created_at_date = time.strftime(
+                        "%Y-%m-%d", time.localtime(created_at_timestamp)
+                    )
+
+                user_context += f"{doc_idx + 1}. [{created_at_date}] {doc}\n"
+
+    form_data["messages"] = add_or_update_system_message(
+        f"User Context:\n{user_context}\n", form_data["messages"], append=True
+    )
+
+    return form_data
+
+
 async def chat_web_search_handler(
     request: Request, form_data: dict, extra_params: dict, user
 ):
@@ -774,6 +807,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
 
     features = form_data.pop("features", None)
     if features:
+        if "memory" in features and features["memory"]:
+            form_data = await chat_memory_handler(
+                request, form_data, extra_params, user
+            )
+
         if "web_search" in features and features["web_search"]:
             form_data = await chat_web_search_handler(
                 request, form_data, extra_params, user

+ 7 - 2
backend/open_webui/utils/misc.py

@@ -130,7 +130,9 @@ def prepend_to_first_user_message_content(
     return messages
 
 
-def add_or_update_system_message(content: str, messages: list[dict]):
+def add_or_update_system_message(
+    content: str, messages: list[dict], append: bool = False
+):
     """
     Adds a new system message at the beginning of the messages list
     or updates the existing system message at the beginning.
@@ -141,7 +143,10 @@ def add_or_update_system_message(content: str, messages: list[dict]):
     """
 
     if messages and messages[0].get("role") == "system":
-        messages[0]["content"] = f"{content}\n{messages[0]['content']}"
+        if append:
+            messages[0]["content"] = f"{messages[0]['content']}\n{content}"
+        else:
+            messages[0]["content"] = f"{content}\n{messages[0]['content']}"
     else:
         # Insert at the beginning
         messages.insert(0, {"role": "system", "content": content})

+ 4 - 34
src/lib/components/chat/Chat.svelte

@@ -1431,7 +1431,6 @@
 					model: model.id,
 					modelName: model.name ?? model.id,
 					modelIdx: modelIdx ? modelIdx : _modelIdx,
-					userContext: null,
 					timestamp: Math.floor(Date.now() / 1000) // Unix epoch
 				};
 
@@ -1486,32 +1485,6 @@
 
 					let responseMessageId =
 						responseMessageIds[`${modelId}-${modelIdx ? modelIdx : _modelIdx}`];
-					let responseMessage = _history.messages[responseMessageId];
-
-					let userContext = null;
-					if ($settings?.memory ?? false) {
-						if (userContext === null) {
-							const res = await queryMemory(localStorage.token, prompt).catch((error) => {
-								toast.error(`${error}`);
-								return null;
-							});
-							if (res) {
-								if (res.documents[0].length > 0) {
-									userContext = res.documents[0].reduce((acc, doc, index) => {
-										const createdAtTimestamp = res.metadatas[0][index].created_at;
-										const createdAtDate = new Date(createdAtTimestamp * 1000)
-											.toISOString()
-											.split('T')[0];
-										return `${acc}${index + 1}. [${createdAtDate}]. ${doc}\n`;
-									}, '');
-								}
-
-								console.log(userContext);
-							}
-						}
-					}
-					responseMessage.userContext = userContext;
-
 					const chatEventEmitter = await getChatEventEmitter(model.id, _chatId);
 
 					scrollToBottom();
@@ -1573,7 +1546,7 @@
 			true;
 
 		let messages = [
-			params?.system || $settings.system || (responseMessage?.userContext ?? null)
+			params?.system || $settings.system
 				? {
 						role: 'system',
 						content: `${promptTemplate(
@@ -1585,11 +1558,7 @@
 										return undefined;
 									})
 								: undefined
-						)}${
-							(responseMessage?.userContext ?? null)
-								? `\n\nUser Context:\n${responseMessage?.userContext ?? ''}`
-								: ''
-						}`
+						)}`
 					}
 				: undefined,
 			...createMessagesList(_history, responseMessageId).map((message) => ({
@@ -1666,7 +1635,8 @@
 						$config?.features?.enable_web_search &&
 						($user?.role === 'admin' || $user?.permissions?.features?.web_search)
 							? webSearchEnabled || ($settings?.webSearch ?? false) === 'always'
-							: false
+							: false,
+					memory: $settings?.memory ?? false
 				},
 				variables: {
 					...getPromptVariables(