Explorar o código

enh: custom reasoning tags

Timothy Jaeryang Baek hai 1 mes
pai
achega
e39ce16a86

+ 5 - 0
backend/open_webui/main.py

@@ -1437,11 +1437,15 @@ async def chat_completion(
         stream_delta_chunk_size = form_data.get("params", {}).get(
             "stream_delta_chunk_size"
         )
+        reasoning_tags = form_data.get("params", {}).get("reasoning_tags")
 
         # Model Params
         if model_info_params.get("stream_delta_chunk_size"):
             stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size")
 
+        if model_info_params.get("reasoning_tags") is not None:
+            reasoning_tags = model_info_params.get("reasoning_tags")
+
         metadata = {
             "user_id": user.id,
             "chat_id": form_data.pop("chat_id", None),
@@ -1457,6 +1461,7 @@ async def chat_completion(
             "direct": model_item.get("direct", False),
             "params": {
                 "stream_delta_chunk_size": stream_delta_chunk_size,
+                "reasoning_tags": reasoning_tags,
                 "function_calling": (
                     "native"
                     if (

+ 39 - 29
backend/open_webui/utils/middleware.py

@@ -111,6 +111,20 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
+DEFAULT_REASONING_TAGS = [
+    ("<think>", "</think>"),
+    ("<thinking>", "</thinking>"),
+    ("<reason>", "</reason>"),
+    ("<reasoning>", "</reasoning>"),
+    ("<thought>", "</thought>"),
+    ("<Thought>", "</Thought>"),
+    ("<|begin_of_thought|>", "<|end_of_thought|>"),
+    ("◁think▷", "◁/think▷"),
+]
+DEFAULT_SOLUTION_TAGS = [("<|begin_of_solution|>", "<|end_of_solution|>")]
+DEFAULT_CODE_INTERPRETER_TAGS = [("<code_interpreter>", "</code_interpreter>")]
+
+
 async def chat_completion_tools_handler(
     request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
 ) -> tuple[dict, dict]:
@@ -694,6 +708,7 @@ def apply_params_to_form_data(form_data, model):
         "stream_response": bool,
         "stream_delta_chunk_size": int,
         "function_calling": str,
+        "reasoning_tags": list,
         "system": str,
     }
 
@@ -1811,27 +1826,23 @@ async def process_chat_response(
                 }
             ]
 
-            # We might want to disable this by default
-            DETECT_REASONING = True
-            DETECT_SOLUTION = True
+            reasoning_tags_param = metadata.get("params", {}).get("reasoning_tags")
+            DETECT_REASONING_TAGS = reasoning_tags_param is not False
             DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
                 "code_interpreter", False
             )
 
-            reasoning_tags = [
-                ("<think>", "</think>"),
-                ("<thinking>", "</thinking>"),
-                ("<reason>", "</reason>"),
-                ("<reasoning>", "</reasoning>"),
-                ("<thought>", "</thought>"),
-                ("<Thought>", "</Thought>"),
-                ("<|begin_of_thought|>", "<|end_of_thought|>"),
-                ("◁think▷", "◁/think▷"),
-            ]
-
-            code_interpreter_tags = [("<code_interpreter>", "</code_interpreter>")]
-
-            solution_tags = [("<|begin_of_solution|>", "<|end_of_solution|>")]
+            reasoning_tags = []
+            if DETECT_REASONING_TAGS:
+                if (
+                    isinstance(reasoning_tags_param, list)
+                    and len(reasoning_tags_param) == 2
+                ):
+                    reasoning_tags = [
+                        (reasoning_tags_param[0], reasoning_tags_param[1])
+                    ]
+                else:
+                    reasoning_tags = DEFAULT_REASONING_TAGS
 
             try:
                 for event in events:
@@ -2083,7 +2094,7 @@ async def process_chat_response(
                                             content_blocks[-1]["content"] + value
                                         )
 
-                                        if DETECT_REASONING:
+                                        if DETECT_REASONING_TAGS:
                                             content, content_blocks, _ = (
                                                 tag_content_handler(
                                                     "reasoning",
@@ -2093,29 +2104,28 @@ async def process_chat_response(
                                                 )
                                             )
 
-                                        if DETECT_CODE_INTERPRETER:
-                                            content, content_blocks, end = (
+                                            content, content_blocks, _ = (
                                                 tag_content_handler(
-                                                    "code_interpreter",
-                                                    code_interpreter_tags,
+                                                    "solution",
+                                                    DEFAULT_SOLUTION_TAGS,
                                                     content,
                                                     content_blocks,
                                                 )
                                             )
 
-                                            if end:
-                                                break
-
-                                        if DETECT_SOLUTION:
-                                            content, content_blocks, _ = (
+                                        if DETECT_CODE_INTERPRETER:
+                                            content, content_blocks, end = (
                                                 tag_content_handler(
-                                                    "solution",
-                                                    solution_tags,
+                                                    "code_interpreter",
+                                                    DEFAULT_CODE_INTERPRETER_TAGS,
                                                     content,
                                                     content_blocks,
                                                 )
                                             )
 
+                                            if end:
+                                                break
+
                                         if ENABLE_REALTIME_CHAT_SAVE:
                                             # Save message in the database
                                             Chats.upsert_message_to_chat_by_id_and_message_id(

+ 1 - 0
backend/open_webui/utils/payload.py

@@ -63,6 +63,7 @@ def remove_open_webui_params(params: dict) -> dict:
         "stream_response": bool,
         "stream_delta_chunk_size": int,
         "function_calling": str,
+        "reasoning_tags": list,
         "system": str,
     }
 

+ 64 - 0
src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte

@@ -17,6 +17,7 @@
 		stream_response: null, // Set stream responses for this model individually
 		stream_delta_chunk_size: null, // Set the chunk size for streaming responses
 		function_calling: null,
+		reasoning_tags: null,
 		seed: null,
 		stop: null,
 		temperature: null,
@@ -175,6 +176,69 @@
 		</Tooltip>
 	</div>
 
+	<div class=" py-0.5 w-full justify-between">
+		<Tooltip
+			content={$i18n.t('Custom reasoning tags to use for the model.')}
+			placement="top-start"
+			className="inline-tooltip"
+		>
+			<div class="flex w-full justify-between">
+				<div class=" self-center text-xs font-medium">
+					{$i18n.t('Reasoning Tags')}
+				</div>
+				<button
+					class="p-1 px-3 text-xs flex rounded-sm transition shrink-0 outline-hidden"
+					type="button"
+					on:click={() => {
+						if ((params?.reasoning_tags ?? null) === null) {
+							params.reasoning_tags = ['', ''];
+						} else if ((params?.reasoning_tags ?? []).length === 2) {
+							params.reasoning_tags = true;
+						} else if ((params?.reasoning_tags ?? null) !== false) {
+							params.reasoning_tags = false;
+						} else {
+							params.reasoning_tags = null;
+						}
+					}}
+				>
+					{#if (params?.reasoning_tags ?? null) === null}
+						<span class="ml-2 self-center"> {$i18n.t('Default')} </span>
+					{:else if (params?.reasoning_tags ?? null) === true}
+						<span class="ml-2 self-center"> {$i18n.t('Enabled')} </span>
+					{:else if (params?.reasoning_tags ?? null) === false}
+						<span class="ml-2 self-center"> {$i18n.t('Disabled')} </span>
+					{:else}
+						<span class="ml-2 self-center"> {$i18n.t('Custom')} </span>
+					{/if}
+				</button>
+			</div>
+		</Tooltip>
+
+		{#if ![true, false, null].includes(params?.reasoning_tags ?? null) && (params?.reasoning_tags ?? []).length === 2}
+			<div class="flex mt-0.5 space-x-2">
+				<div class=" flex-1">
+					<input
+						class="text-sm w-full bg-transparent outline-hidden outline-none"
+						type="text"
+						placeholder={$i18n.t('Start Tag')}
+						bind:value={params.reasoning_tags[0]}
+						autocomplete="off"
+					/>
+				</div>
+
+				<div class=" flex-1">
+					<input
+						class="text-sm w-full bg-transparent outline-hidden outline-none"
+						type="text"
+						placeholder={$i18n.t('End Tag')}
+						bind:value={params.reasoning_tags[1]}
+						autocomplete="off"
+					/>
+				</div>
+			</div>
+		{/if}
+	</div>
+
 	<div class=" py-0.5 w-full justify-between">
 		<Tooltip
 			content={$i18n.t(