Timothy Jaeryang Baek 1 tuần trước cách đây
mục cha
commit
91b6483aa9
2 tập tin đã thay đổi với 227 bổ sung278 xóa
  1. 17 0
      backend/open_webui/socket/main.py
  2. 210 278
      backend/open_webui/utils/middleware.py

+ 17 - 0
backend/open_webui/socket/main.py

@@ -705,6 +705,23 @@ def get_event_emitter(request_info, update_db=True):
                     },
                 )
 
+            if "type" in event_data and event_data["type"] == "embeds":
+                message = Chats.get_message_by_id_and_message_id(
+                    request_info["chat_id"],
+                    request_info["message_id"],
+                )
+
+                embeds = event_data.get("data", {}).get("embeds", [])
+                embeds.extend(message.get("embeds", []))
+
+                Chats.upsert_message_to_chat_by_id_and_message_id(
+                    request_info["chat_id"],
+                    request_info["message_id"],
+                    {
+                        "embeds": embeds,
+                    },
+                )
+
             if "type" in event_data and event_data["type"] == "files":
                 message = Chats.get_message_by_id_and_message_id(
                     request_info["chat_id"],

+ 210 - 278
backend/open_webui/utils/middleware.py

@@ -133,6 +133,149 @@ DEFAULT_SOLUTION_TAGS = [("<|begin_of_solution|>", "<|end_of_solution|>")]
 DEFAULT_CODE_INTERPRETER_TAGS = [("<code_interpreter>", "</code_interpreter>")]
 
 
+def process_tool_result(
+    request,
+    tool_function_name,
+    tool_result,
+    tool_type,
+    direct_tool=False,
+    metadata=None,
+    user=None,
+):
+    tool_result_embeds = []
+
+    if isinstance(tool_result, HTMLResponse):
+        content_disposition = tool_result.headers.get("Content-Disposition", "")
+        if "inline" in content_disposition:
+            content = tool_result.body.decode("utf-8")
+            tool_result_embeds.append(content)
+
+            if 200 <= tool_result.status_code < 300:
+                tool_result = {
+                    "status": "success",
+                    "code": "ui_component",
+                    "message": f"{tool_function_name}: Embedded UI result is active and visible to the user.",
+                }
+            elif 400 <= tool_result.status_code < 500:
+                tool_result = {
+                    "status": "error",
+                    "code": "ui_component",
+                    "message": f"{tool_function_name}: Client error {tool_result.status_code} from embedded UI result.",
+                }
+            elif 500 <= tool_result.status_code < 600:
+                tool_result = {
+                    "status": "error",
+                    "code": "ui_component",
+                    "message": f"{tool_function_name}: Server error {tool_result.status_code} from embedded UI result.",
+                }
+            else:
+                tool_result = {
+                    "status": "error",
+                    "code": "ui_component",
+                    "message": f"{tool_function_name}: Unexpected status code {tool_result.status_code} from embedded UI result.",
+                }
+        else:
+            tool_result = tool_result.body.decode("utf-8")
+
+    elif (tool_type == "external" and isinstance(tool_result, tuple)) or (
+        direct_tool and isinstance(tool_result, list) and len(tool_result) == 2
+    ):
+        tool_result, tool_response_headers = tool_result
+
+        try:
+            if not isinstance(tool_response_headers, dict):
+                tool_response_headers = dict(tool_response_headers)
+        except Exception as e:
+            tool_response_headers = {}
+            log.debug(e)
+
+        if tool_response_headers and isinstance(tool_response_headers, dict):
+            content_disposition = tool_response_headers.get(
+                "Content-Disposition",
+                tool_response_headers.get("content-disposition", ""),
+            )
+
+            if "inline" in content_disposition:
+                content_type = tool_response_headers.get(
+                    "Content-Type",
+                    tool_response_headers.get("content-type", ""),
+                )
+                location = tool_response_headers.get(
+                    "Location",
+                    tool_response_headers.get("location", ""),
+                )
+
+                if "text/html" in content_type:
+                    # Display as iframe embed
+                    tool_result_embeds.append(tool_result)
+                    tool_result = {
+                        "status": "success",
+                        "code": "ui_component",
+                        "message": f"{tool_function_name}: Embedded UI result is active and visible to the user.",
+                    }
+                elif location:
+                    tool_result_embeds.append(location)
+                    tool_result = {
+                        "status": "success",
+                        "code": "ui_component",
+                        "message": f"{tool_function_name}: Embedded UI result is active and visible to the user.",
+                    }
+
+    tool_result_files = []
+
+    if isinstance(tool_result, list):
+        if tool_type == "mcp":  # MCP
+            tool_response = []
+            for item in tool_result:
+                if isinstance(item, dict):
+                    if item.get("type") == "text":
+                        text = item.get("text", "")
+                        if isinstance(text, str):
+                            try:
+                                text = json.loads(text)
+                            except json.JSONDecodeError:
+                                pass
+                        tool_response.append(text)
+                    elif item.get("type") in ["image", "audio"]:
+                        file_url = get_file_url_from_base64(
+                            request,
+                            f"data:{item.get('mimeType')};base64,{item.get('data', item.get('blob', ''))}",
+                            {
+                                "chat_id": metadata.get("chat_id", None),
+                                "message_id": metadata.get("message_id", None),
+                                "session_id": metadata.get("session_id", None),
+                                "result": item,
+                            },
+                            user,
+                        )
+
+                        tool_result_files.append(
+                            {
+                                "type": item.get("type", "data"),
+                                "url": file_url,
+                            }
+                        )
+            tool_result = tool_response[0] if len(tool_response) == 1 else tool_response
+        else:  # OpenAPI
+            for item in tool_result:
+                if isinstance(item, str) and item.startswith("data:"):
+                    tool_result_files.append(
+                        {
+                            "type": "data",
+                            "content": item,
+                        }
+                    )
+                    tool_result.remove(item)
+
+    if isinstance(tool_result, list):
+        tool_result = {"results": tool_result}
+
+    if isinstance(tool_result, dict) or isinstance(tool_result, list):
+        tool_result = json.dumps(tool_result, indent=2, ensure_ascii=False)
+
+    return tool_result, tool_result_files, tool_result_embeds
+
+
 async def chat_completion_tools_handler(
     request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
 ) -> tuple[dict, dict]:
@@ -172,6 +315,7 @@ async def chat_completion_tools_handler(
         }
 
     event_caller = extra_params["__event_call__"]
+    event_emitter = extra_params["__event_emitter__"]
     metadata = extra_params["__metadata__"]
 
     task_model_id = get_task_model_id(
@@ -226,8 +370,14 @@ async def chat_completion_tools_handler(
 
                 tool_function_params = tool_call.get("parameters", {})
 
+                tool = None
+                tool_type = ""
+                direct_tool = False
+
                 try:
                     tool = tools[tool_function_name]
+                    tool_type = tool.get("type", "")
+                    direct_tool = tool.get("direct", False)
 
                     spec = tool.get("spec", {})
                     allowed_params = (
@@ -259,106 +409,31 @@ async def chat_completion_tools_handler(
                 except Exception as e:
                     tool_result = str(e)
 
-                tool_result_embeds = []
-                if isinstance(tool_result, HTMLResponse):
-                    content_disposition = tool_result.headers.get(
-                        "Content-Disposition", ""
+                tool_result, tool_result_files, tool_result_embeds = (
+                    process_tool_result(
+                        request,
+                        tool_function_name,
+                        tool_result,
+                        tool_type,
+                        direct_tool,
+                        metadata,
+                        user,
                     )
-                    if "inline" in content_disposition:
-                        content = tool_result.body.decode("utf-8")
-                        tool_result_embeds.append(content)
-
-                        if 200 <= tool_result.status_code < 300:
-                            tool_result = {
-                                "status": "success",
-                                "code": "ui_component",
-                                "message": "Embedded UI result is active and visible to the user.",
-                            }
-                        elif 400 <= tool_result.status_code < 500:
-                            tool_result = {
-                                "status": "error",
-                                "code": "ui_component",
-                                "message": f"Client error {tool_result.status_code} from embedded UI result.",
-                            }
-                        elif 500 <= tool_result.status_code < 600:
-                            tool_result = {
-                                "status": "error",
-                                "code": "ui_component",
-                                "message": f"Server error {tool_result.status_code} from embedded UI result.",
-                            }
-                        else:
-                            tool_result = {
-                                "status": "error",
-                                "code": "ui_component",
-                                "message": f"Unexpected status code {tool_result.status_code} from embedded UI result.",
-                            }
-                    else:
-                        tool_result = tool_result.body.decode("utf-8")
-
-                elif (
-                    tool.get("type") == "external" and isinstance(tool_result, tuple)
-                ) or (
-                    tool.get("direct", True)
-                    and isinstance(tool_result, list)
-                    and len(tool_result) == 2
-                ):
-                    tool_result, tool_response_headers = tool_result
-
-                    try:
-                        if not isinstance(tool_response_headers, dict):
-                            tool_response_headers = dict(tool_response_headers)
-                    except Exception as e:
-                        tool_response_headers = {}
-                        log.debug(e)
+                )
 
-                    if tool_response_headers and isinstance(
-                        tool_response_headers, dict
-                    ):
-                        content_disposition = tool_response_headers.get(
-                            "Content-Disposition",
-                            tool_response_headers.get("content-disposition", ""),
+                if event_emitter:
+                    if tool_result_files:
+                        await event_emitter(
+                            {
+                                "type": "files",
+                                "data": {
+                                    "files": tool_result_files,
+                                },
+                            }
                         )
 
-                        if "inline" in content_disposition:
-                            content_type = tool_response_headers.get(
-                                "Content-Type",
-                                tool_response_headers.get("content-type", ""),
-                            )
-                            location = tool_response_headers.get(
-                                "Location",
-                                tool_response_headers.get("location", ""),
-                            )
-
-                            if "text/html" in content_type:
-                                # Display as iframe embed
-                                tool_result_embeds.append(tool_result)
-                                tool_result = {
-                                    "status": "success",
-                                    "code": "ui_component",
-                                    "message": "Embedded UI result is active and visible to the user.",
-                                }
-                            elif location:
-                                tool_result_embeds.append(location)
-                                tool_result = {
-                                    "status": "success",
-                                    "code": "ui_component",
-                                    "message": "Embedded UI result is active and visible to the user.",
-                                }
-
-                tool_result_files = []
-                if isinstance(tool_result, list):
-                    for item in tool_result:
-                        # check if string
-                        if isinstance(item, str) and item.startswith("data:"):
-                            tool_result_files.append(item)
-                            tool_result.remove(item)
-
-                if isinstance(tool_result, dict) or isinstance(tool_result, list):
-                    tool_result = json.dumps(tool_result, indent=2)
-
-                if tool_result_embeds:
-                    if event_caller:
-                        await event_caller(
+                    if tool_result_embeds:
+                        await event_emitter(
                             {
                                 "type": "embeds",
                                 "data": {
@@ -367,7 +442,13 @@ async def chat_completion_tools_handler(
                             }
                         )
 
-                if isinstance(tool_result, str):
+                print(
+                    f"Tool {tool_function_name} result: {tool_result}",
+                    tool_result_files,
+                    tool_result_embeds,
+                )
+
+                if tool_result:
                     tool = tools[tool_function_name]
                     tool_id = tool.get("tool_id", "")
 
@@ -381,18 +462,19 @@ async def chat_completion_tools_handler(
                     sources.append(
                         {
                             "source": {
-                                "name": (f"TOOL:{tool_name}"),
+                                "name": (f"{tool_name}"),
                             },
-                            "document": [tool_result],
+                            "document": [str(tool_result)],
                             "metadata": [
                                 {
-                                    "source": (f"TOOL:{tool_name}"),
+                                    "source": (f"{tool_name}"),
                                     "parameters": tool_function_params,
                                 }
                             ],
                             "tool_result": True,
                         }
                     )
+
                     # Citation is not enabled for this tool
                     body["messages"] = add_or_update_user_message(
                         f"\nTool `{tool_name}` Output: {tool_result}",
@@ -1267,9 +1349,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
         citation_idx_map = {}
 
         for source in sources:
-            is_tool_result = source.get("tool_result", False)
-
-            if "document" in source and not is_tool_result:
+            if "document" in source:
                 for document_text, document_metadata in zip(
                     source["document"], source["metadata"]
                 ):
@@ -1330,6 +1410,10 @@ async def process_chat_payload(request, form_data, user, metadata, model):
             }
         )
 
+    print("Final form_data:", form_data)
+    print("Final metadata:", metadata)
+    print("Final events:", events)
+
     return form_data, metadata, events
 
 
@@ -2538,7 +2622,9 @@ async def process_chat_response(
 
                         print("tool_call", tool_call)
                         tool_call_id = tool_call.get("id", "")
-                        tool_name = tool_call.get("function", {}).get("name", "")
+                        tool_function_name = tool_call.get("function", {}).get(
+                            "name", ""
+                        )
                         tool_args = tool_call.get("function", {}).get("arguments", "{}")
 
                         tool_function_params = {}
@@ -2568,11 +2654,17 @@ async def process_chat_response(
                         )
 
                         tool_result = None
+                        tool = None
+                        tool_type = None
+                        direct_tool = False
 
-                        if tool_name in tools:
-                            tool = tools[tool_name]
+                        if tool_function_name in tools:
+                            tool = tools[tool_function_name]
                             spec = tool.get("spec", {})
 
+                            tool_type = tool.get("type", "")
+                            direct_tool = tool.get("direct", False)
+
                             try:
                                 allowed_params = (
                                     spec.get("parameters", {})
@@ -2586,13 +2678,13 @@ async def process_chat_response(
                                     if k in allowed_params
                                 }
 
-                                if tool.get("direct", False):
+                                if direct_tool:
                                     tool_result = await event_caller(
                                         {
                                             "type": "execute:tool",
                                             "data": {
                                                 "id": str(uuid4()),
-                                                "name": tool_name,
+                                                "name": tool_function_name,
                                                 "params": tool_function_params,
                                                 "server": tool.get("server", {}),
                                                 "session_id": metadata.get(
@@ -2611,176 +2703,17 @@ async def process_chat_response(
                             except Exception as e:
                                 tool_result = str(e)
 
-                        tool_result_embeds = []
-                        if isinstance(tool_result, HTMLResponse):
-                            content_disposition = tool_result.headers.get(
-                                "Content-Disposition", ""
-                            )
-                            if "inline" in content_disposition:
-                                content = tool_result.body.decode("utf-8")
-                                tool_result_embeds.append(content)
-
-                                if 200 <= tool_result.status_code < 300:
-                                    tool_result = {
-                                        "status": "success",
-                                        "code": "ui_component",
-                                        "message": "Embedded UI result is active and visible to the user.",
-                                    }
-                                elif 400 <= tool_result.status_code < 500:
-                                    tool_result = {
-                                        "status": "error",
-                                        "code": "ui_component",
-                                        "message": f"Client error {tool_result.status_code} from embedded UI result.",
-                                    }
-                                elif 500 <= tool_result.status_code < 600:
-                                    tool_result = {
-                                        "status": "error",
-                                        "code": "ui_component",
-                                        "message": f"Server error {tool_result.status_code} from embedded UI result.",
-                                    }
-                                else:
-                                    tool_result = {
-                                        "status": "error",
-                                        "code": "ui_component",
-                                        "message": f"Unexpected status code {tool_result.status_code} from embedded UI result.",
-                                    }
-                            else:
-                                tool_result = tool_result.body.decode("utf-8")
-
-                        elif (
-                            tool.get("type") == "external"
-                            and isinstance(tool_result, tuple)
-                        ) or (
-                            tool.get("direct", True)
-                            and isinstance(tool_result, list)
-                            and len(tool_result) == 2
-                        ):
-                            tool_result, tool_response_headers = tool_result
-
-                            try:
-                                if not isinstance(tool_response_headers, dict):
-                                    tool_response_headers = dict(tool_response_headers)
-                            except Exception as e:
-                                tool_response_headers = {}
-                                log.debug(e)
-
-                            print(tool_response_headers)
-                            print(type(tool_response_headers))
-
-                            if tool_response_headers and isinstance(
-                                tool_response_headers, dict
-                            ):
-                                content_disposition = tool_response_headers.get(
-                                    "Content-Disposition",
-                                    tool_response_headers.get(
-                                        "content-disposition", ""
-                                    ),
-                                )
-
-                                if "inline" in content_disposition:
-                                    content_type = tool_response_headers.get(
-                                        "Content-Type",
-                                        tool_response_headers.get("content-type", ""),
-                                    )
-                                    location = tool_response_headers.get(
-                                        "Location",
-                                        tool_response_headers.get("location", ""),
-                                    )
-
-                                    if "text/html" in content_type:
-                                        # Display as iframe embed
-                                        tool_result_embeds.append(tool_result)
-                                        tool_result = {
-                                            "status": "success",
-                                            "code": "ui_component",
-                                            "message": "Embedded UI result is active and visible to the user.",
-                                        }
-                                    elif location:
-                                        tool_result_embeds.append(location)
-                                        tool_result = {
-                                            "status": "success",
-                                            "code": "ui_component",
-                                            "message": "Embedded UI result is active and visible to the user.",
-                                        }
-
-                        tool_result_files = []
-                        if isinstance(tool_result, list):
-                            if tool.get("type") == "mcp":  # MCP
-                                tool_response = []
-                                for item in tool_result:
-                                    if isinstance(item, dict):
-                                        if item.get("type") == "text":
-                                            text = item.get("text", "")
-                                            if isinstance(text, str):
-                                                try:
-                                                    text = json.loads(text)
-                                                except json.JSONDecodeError:
-                                                    pass
-                                            tool_response.append(text)
-                                        elif item.get("type") in ["image", "audio"]:
-                                            file_url = get_file_url_from_base64(
-                                                request,
-                                                f"data:{item.get('mimeType')};base64,{item.get('data', item.get('blob', ''))}",
-                                                {
-                                                    "chat_id": metadata.get(
-                                                        "chat_id", None
-                                                    ),
-                                                    "message_id": metadata.get(
-                                                        "message_id", None
-                                                    ),
-                                                    "session_id": metadata.get(
-                                                        "session_id", None
-                                                    ),
-                                                    "result": item,
-                                                },
-                                                user,
-                                            )
-
-                                            tool_result_files.append(
-                                                {
-                                                    "type": item.get("type", "data"),
-                                                    "url": file_url,
-                                                }
-                                            )
-                                tool_result = (
-                                    tool_response[0]
-                                    if len(tool_response) == 1
-                                    else tool_response
-                                )
-                            else:  # OpenAPI
-                                for item in tool_result:
-                                    # check if string
-                                    if isinstance(item, str) and item.startswith(
-                                        "data:"
-                                    ):
-                                        tool_result_files.append(
-                                            {
-                                                "type": "data",
-                                                "content": item,
-                                            }
-                                        )
-                                        tool_result.remove(item)
-
-                        if tool_result_files:
-                            if not isinstance(tool_result, list):
-                                tool_result = [
-                                    tool_result,
-                                ]
-
-                            for file in tool_result_files:
-                                tool_result.append(
-                                    {
-                                        "type": file.get("type", "data"),
-                                        "content": "Result is being displayed as a file.",
-                                    }
-                                )
-
-                        if isinstance(tool_result, dict) or isinstance(
-                            tool_result, list
-                        ):
-                            tool_result = json.dumps(
-                                tool_result, indent=2, ensure_ascii=False
+                        tool_result, tool_result_files, tool_result_embeds = (
+                            process_tool_result(
+                                request,
+                                tool_function_name,
+                                tool_result,
+                                tool_type,
+                                direct_tool,
+                                metadata,
+                                user,
                             )
+                        )
 
                         results.append(
                             {
@@ -2800,7 +2733,6 @@ async def process_chat_response(
                         )
 
                     content_blocks[-1]["results"] = results
-
                     content_blocks.append(
                         {
                             "type": "text",