Timothy Jaeryang Baek 2 months ago
parent
commit
c9e9ce931b
2 changed files with 27 additions and 4 deletions
  1. 8 2
      backend/open_webui/utils/middleware.py
  2. 19 2
      backend/open_webui/utils/tools.py

+ 8 - 2
backend/open_webui/utils/middleware.py

@@ -227,7 +227,9 @@ async def chat_completion_tools_handler(
                 if isinstance(tool_result, str):
                 if isinstance(tool_result, str):
                     tool = tools[tool_function_name]
                     tool = tools[tool_function_name]
                     tool_id = tool.get("toolkit_id", "")
                     tool_id = tool.get("toolkit_id", "")
-                    if tool.get("citation", False) or tool.get("direct", False):
+                    if tool.get("metadata", {}).get("citation", False) or tool.get(
+                        "direct", False
+                    ):
 
 
                         sources.append(
                         sources.append(
                             {
                             {
@@ -267,7 +269,11 @@ async def chat_completion_tools_handler(
                             }
                             }
                         )
                         )
 
 
-                    if tools[tool_function_name].get("file_handler", False):
+                    if (
+                        tools[tool_function_name]
+                        .get("metadata", {})
+                        .get("file_handler", False)
+                    ):
                         skip_files = True
                         skip_files = True
 
 
             # check if "tool_calls" in result
             # check if "tool_calls" in result

+ 19 - 2
backend/open_webui/utils/tools.py

@@ -47,6 +47,20 @@ def get_tools(
     for tool_id in tool_ids:
     for tool_id in tool_ids:
         tools = Tools.get_tool_by_id(tool_id)
         tools = Tools.get_tool_by_id(tool_id)
         if tools is None:
         if tools is None:
+
+            tool_dict = {
+                "spec": spec,
+                "callable": callable,
+                "toolkit_id": tool_id,
+                "pydantic_model": function_to_pydantic_model(callable),
+                # Misc info
+                "metadata": {
+                    "file_handler": hasattr(module, "file_handler")
+                    and module.file_handler,
+                    "citation": hasattr(module, "citation") and module.citation,
+                },
+            }
+
             continue
             continue
 
 
         module = request.app.state.TOOLS.get(tool_id, None)
         module = request.app.state.TOOLS.get(tool_id, None)
@@ -97,8 +111,11 @@ def get_tools(
                 "toolkit_id": tool_id,
                 "toolkit_id": tool_id,
                 "pydantic_model": function_to_pydantic_model(callable),
                 "pydantic_model": function_to_pydantic_model(callable),
                 # Misc info
                 # Misc info
-                "file_handler": hasattr(module, "file_handler") and module.file_handler,
-                "citation": hasattr(module, "citation") and module.citation,
+                "metadata": {
+                    "file_handler": hasattr(module, "file_handler")
+                    and module.file_handler,
+                    "citation": hasattr(module, "citation") and module.citation,
+                },
             }
             }
 
 
             # TODO: if collision, prepend toolkit name
             # TODO: if collision, prepend toolkit name