Răsfoiți Sursa

Merge pull request #14682 from olivier-lacroix/genai-tool-function

refactor: Improve tool callable generation to allow for genai native function call
Tim Jaeryang Baek 1 lună în urmă
părinte
comite
96643f5b6d
1 a modificat fișierele cu 26 adăugiri și 8 ștergeri
  1. 26 8
      backend/open_webui/utils/tools.py

+ 26 - 8
backend/open_webui/utils/tools.py

@@ -57,16 +57,34 @@ def get_async_tool_function_and_apply_extra_params(
     extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters}
     partial_func = partial(function, **extra_params)
 
+    # Remove the 'frozen' keyword arguments from the signature
+    # python-genai uses the signature to infer the tool properties for native function calling
+    parameters = []
+    for name, parameter in sig.parameters.items():
+        # Exclude keyword arguments that are frozen
+        if name in extra_params:
+            continue
+        # Keep remaining parameters
+        parameters.append(parameter)
+
+    new_sig = inspect.Signature(
+        parameters=parameters, return_annotation=sig.return_annotation
+    )
+
     if inspect.iscoroutinefunction(function):
-        update_wrapper(partial_func, function)
-        return partial_func
+        # wrap the functools.partial as python-genai has trouble with it
+        # https://github.com/googleapis/python-genai/issues/907
+        async def new_function(*args, **kwargs):
+            return await partial_func(*args, **kwargs)
     else:
-        # Make it a coroutine function
+        # Make it a coroutine function when it is not already
         async def new_function(*args, **kwargs):
             return partial_func(*args, **kwargs)
 
-        update_wrapper(new_function, function)
-        return new_function
+    update_wrapper(new_function, function)
+    new_function.__signature__ = new_sig
+
+    return new_function
 
 
 async def get_tools(
@@ -293,15 +311,15 @@ def convert_function_to_pydantic_model(func: Callable) -> type[BaseModel]:
 
     field_defs = {}
     for name, param in parameters.items():
-
         type_hint = type_hints.get(name, Any)
         default_value = param.default if param.default is not param.empty else ...
 
         param_description = function_param_descriptions.get(name, None)
 
         if param_description:
-            field_defs[name] = type_hint, Field(
-                default_value, description=param_description
+            field_defs[name] = (
+                type_hint,
+                Field(default_value, description=param_description),
             )
         else:
             field_defs[name] = type_hint, default_value