Browse Source

refac: rag context handling

Timothy Jaeryang Baek 2 tuần trước cách đây
mục cha
commit
f096e99059
2 tập tin đã thay đổi với 36 bổ sung39 xóa
  1. 9 20
      backend/open_webui/utils/middleware.py
  2. 27 19
      backend/open_webui/utils/misc.py

+ 9 - 20
backend/open_webui/utils/middleware.py

@@ -1171,26 +1171,15 @@ async def process_chat_payload(request, form_data, user, metadata, model):
             raise Exception("No user message found")
 
         if context_string != "":
-            # Workaround for Ollama 2.0+ system prompt issue
-            # TODO: replace with add_or_update_system_message
-            if model.get("owned_by") == "ollama":
-                form_data["messages"] = prepend_to_first_user_message_content(
-                    rag_template(
-                        request.app.state.config.RAG_TEMPLATE,
-                        context_string,
-                        prompt,
-                    ),
-                    form_data["messages"],
-                )
-            else:
-                form_data["messages"] = add_or_update_system_message(
-                    rag_template(
-                        request.app.state.config.RAG_TEMPLATE,
-                        context_string,
-                        prompt,
-                    ),
-                    form_data["messages"],
-                )
+            form_data["messages"] = add_or_update_user_message(
+                rag_template(
+                    request.app.state.config.RAG_TEMPLATE,
+                    context_string,
+                    prompt,
+                ),
+                form_data["messages"],
+                append=False,
+            )
 
     # If there are citations, add them to the data_items
     sources = [

+ 27 - 19
backend/open_webui/utils/misc.py

@@ -120,19 +120,20 @@ def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]
     return get_system_message(messages), remove_system_message(messages)
 
 
-def prepend_to_first_user_message_content(
-    content: str, messages: list[dict]
-) -> list[dict]:
-    for message in messages:
-        if message["role"] == "user":
-            if isinstance(message["content"], list):
-                for item in message["content"]:
-                    if item["type"] == "text":
-                        item["text"] = f"{content}\n{item['text']}"
-            else:
-                message["content"] = f"{content}\n{message['content']}"
-            break
-    return messages
+def update_message_content(message: dict, content: str, append: bool = True) -> dict:
+    if isinstance(message["content"], list):
+        for item in message["content"]:
+            if item["type"] == "text":
+                if append:
+                    item["text"] = f"{item['text']}\n{content}"
+                else:
+                    item["text"] = f"{content}\n{item['text']}"
+    else:
+        if append:
+            message["content"] = f"{message['content']}\n{content}"
+        else:
+            message["content"] = f"{content}\n{message['content']}"
+    return message
 
 
 def add_or_update_system_message(
@@ -148,10 +149,7 @@ def add_or_update_system_message(
     """
 
     if messages and messages[0].get("role") == "system":
-        if append:
-            messages[0]["content"] = f"{messages[0]['content']}\n{content}"
-        else:
-            messages[0]["content"] = f"{content}\n{messages[0]['content']}"
+        messages[0] = update_message_content(messages[0], content, append)
     else:
         # Insert at the beginning
         messages.insert(0, {"role": "system", "content": content})
@@ -159,7 +157,7 @@ def add_or_update_system_message(
     return messages
 
 
-def add_or_update_user_message(content: str, messages: list[dict]):
+def add_or_update_user_message(content: str, messages: list[dict], append: bool = True):
     """
     Adds a new user message at the end of the messages list
     or updates the existing user message at the end.
@@ -170,7 +168,7 @@ def add_or_update_user_message(content: str, messages: list[dict]):
     """
 
     if messages and messages[-1].get("role") == "user":
-        messages[-1]["content"] = f"{messages[-1]['content']}\n{content}"
+        messages[-1] = update_message_content(messages[-1], content, append)
     else:
         # Insert at the end
         messages.append({"role": "user", "content": content})
@@ -178,6 +176,16 @@ def add_or_update_user_message(content: str, messages: list[dict]):
     return messages
 
 
+def prepend_to_first_user_message_content(
+    content: str, messages: list[dict]
+) -> list[dict]:
+    for message in messages:
+        if message["role"] == "user":
+            message = update_message_content(message, content, append=False)
+            break
+    return messages
+
+
 def append_or_update_assistant_message(content: str, messages: list[dict]):
     """
     Adds a new assistant message at the end of the messages list