Prechádzať zdrojové kódy

enh/fix: update extra params for native function calling

Co-Authored-By: Jacob Leksan <63938553+jmleksan@users.noreply.github.com>
Timothy Jaeryang Baek 3 mesiacov pred
rodič
commit
f5c7152a6b

+ 11 - 2
backend/open_webui/utils/middleware.py

@@ -91,7 +91,7 @@ from open_webui.utils.misc import (
     convert_logit_bias_input_to_json,
     get_content_from_message,
 )
-from open_webui.utils.tools import get_tools
+from open_webui.utils.tools import get_tools, get_updated_tool_function
 from open_webui.utils.plugin import load_function_module_by_id
 from open_webui.utils.filter import (
     get_sorted_filter_ids,
@@ -2838,7 +2838,16 @@ async def process_chat_response(
                                     )
 
                                 else:
-                                    tool_function = tool["callable"]
+                                    tool_function = await get_updated_tool_function(
+                                        function=tool["callable"],
+                                        extra_params={
+                                            "__messages__": form_data.get(
+                                                "messages", []
+                                            ),
+                                            "__files__": metadata.get("files", []),
+                                        },
+                                    )
+
                                     tool_result = await tool_function(
                                         **tool_function_params
                                     )

+ 17 - 0
backend/open_webui/utils/tools.py

@@ -85,9 +85,26 @@ def get_async_tool_function_and_apply_extra_params(
     update_wrapper(new_function, function)
     new_function.__signature__ = new_sig
 
+    new_function.__function__ = function  # type: ignore
+    new_function.__extra_params__ = extra_params  # type: ignore
+
     return new_function
 
 
+async def get_updated_tool_function(function: Callable, extra_params: dict):
+    # Get the original function and merge updated params
+    __function__ = getattr(function, "__function__", None)
+    __extra_params__ = getattr(function, "__extra_params__", None)
+
+    if __function__ is not None and __extra_params__ is not None:
+        return await get_async_tool_function_and_apply_extra_params(
+            __function__,
+            {**__extra_params__, **extra_params},
+        )
+
+    return function
+
+
 async def get_tools(
     request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
 ) -> dict[str, dict]: