Prechádzať zdrojové kódy

refac/fix: system prompt duplication

Timothy Jaeryang Baek 5 dní pred
rodič
commit
a1fc99c66f

+ 1 - 1
backend/open_webui/utils/middleware.py

@@ -1004,7 +1004,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     if system_message:
         try:
             form_data = apply_system_prompt_to_body(
-                system_message.get("content"), form_data, metadata, user
+                system_message.get("content"), form_data, metadata, user, replace=True
             )
         except:
             pass

+ 8 - 0
backend/open_webui/utils/misc.py

@@ -136,6 +136,14 @@ def update_message_content(message: dict, content: str, append: bool = True) ->
     return message
 
 
+def replace_system_message_content(content: str, messages: list[dict]) -> dict:
+    for message in messages:
+        if message["role"] == "system":
+            message["content"] = content
+            break
+    return messages
+
+
 def add_or_update_system_message(
     content: str, messages: list[dict], append: bool = False
 ):

+ 15 - 4
backend/open_webui/utils/payload.py

@@ -2,6 +2,7 @@ from open_webui.utils.task import prompt_template, prompt_variables_template
 from open_webui.utils.misc import (
     deep_update,
     add_or_update_system_message,
+    replace_system_message_content,
 )
 
 from typing import Callable, Optional
@@ -10,7 +11,11 @@ import json
 
 # inplace function: form_data is modified
 def apply_system_prompt_to_body(
-    system: Optional[str], form_data: dict, metadata: Optional[dict] = None, user=None
+    system: Optional[str],
+    form_data: dict,
+    metadata: Optional[dict] = None,
+    user=None,
+    replace: bool = False,
 ) -> dict:
     if not system:
         return form_data
@@ -24,9 +29,15 @@ def apply_system_prompt_to_body(
     # Legacy (API Usage)
     system = prompt_template(system, user)
 
-    form_data["messages"] = add_or_update_system_message(
-        system, form_data.get("messages", [])
-    )
+    if replace:
+        form_data["messages"] = replace_system_message_content(
+            system, form_data.get("messages", [])
+        )
+    else:
+        form_data["messages"] = add_or_update_system_message(
+            system, form_data.get("messages", [])
+        )
+
     return form_data