12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813 |
- import inspect
- import logging
- import re
- import inspect
- import aiohttp
- import asyncio
- import yaml
- import json
- from pydantic import BaseModel
- from pydantic.fields import FieldInfo
- from typing import (
- Any,
- Awaitable,
- Callable,
- get_type_hints,
- get_args,
- get_origin,
- Dict,
- List,
- Tuple,
- Union,
- Optional,
- Type,
- )
- from functools import update_wrapper, partial
- from fastapi import Request
- from pydantic import BaseModel, Field, create_model
- from langchain_core.utils.function_calling import (
- convert_to_openai_function as convert_pydantic_model_to_openai_function_spec,
- )
- from open_webui.models.tools import Tools
- from open_webui.models.users import UserModel
- from open_webui.utils.plugin import load_tool_module_by_id
- from open_webui.env import (
- SRC_LOG_LEVELS,
- AIOHTTP_CLIENT_TIMEOUT,
- AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
- AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
- )
- import copy
# Module logger; its level is taken from the "MODELS" entry of SRC_LOG_LEVELS
# (imported from open_webui.env).
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
def get_async_tool_function_and_apply_extra_params(
    function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
    """
    Wrap a tool function as a coroutine function with some arguments pre-bound.

    Only the entries of ``extra_params`` whose keys name a parameter of
    ``function`` are bound (through ``functools.partial``); the others are
    ignored. The wrapper copies ``function``'s metadata via ``update_wrapper``
    and advertises a ``__signature__`` without the bound parameters, because
    python-genai inspects the signature to infer tool properties for native
    function calling.

    Args:
        function: The tool function, either synchronous or ``async``.
        extra_params: Candidate keyword arguments to bind.

    Returns:
        An ``async`` callable that accepts the remaining parameters. A
        synchronous ``function`` is called directly inside the coroutine.
    """
    signature = inspect.signature(function)
    bound_kwargs = {
        key: value
        for key, value in extra_params.items()
        if key in signature.parameters
    }
    bound_call = partial(function, **bound_kwargs)

    # Advertise only the parameters the caller still has to provide.
    public_signature = inspect.Signature(
        parameters=[
            param
            for param_name, param in signature.parameters.items()
            if param_name not in bound_kwargs
        ],
        return_annotation=signature.return_annotation,
    )

    # Hand out a plain ``async def`` instead of the functools.partial, which
    # python-genai has trouble with:
    # https://github.com/googleapis/python-genai/issues/907
    if not inspect.iscoroutinefunction(function):

        async def wrapper(*args, **kwargs):
            # Synchronous tool: call it and return its result from a coroutine.
            return bound_call(*args, **kwargs)

    else:

        async def wrapper(*args, **kwargs):
            return await bound_call(*args, **kwargs)

    update_wrapper(wrapper, function)
    wrapper.__signature__ = public_signature
    return wrapper
def _parse_tool_server_id(
    tool_id: str,
) -> Optional[Tuple[str, str, Optional[List[str]]]]:
    """Split a ``server:...`` tool id into (server_type, server_id, function_names).

    Accepted forms are ``server:<id>`` (type defaults to ``openapi``) and
    ``server:<type>:<id>``; ``<id>`` may carry a ``|fn1,fn2`` suffix.
    Returns None for any other shape.
    """
    splits = tool_id.split(":")
    if len(splits) == 2:
        server_type, server_id = "openapi", splits[1]
    elif len(splits) == 3:
        server_type, server_id = splits[1], splits[2]
    else:
        return None

    function_names = None
    server_id_splits = server_id.split("|")
    if len(server_id_splits) == 2:
        server_id = server_id_splits[0]
        function_names = server_id_splits[1].split(",")
    return server_type, server_id, function_names


def _build_tool_server_auth(
    request: Request, connection: dict, extra_params: dict
) -> Tuple[Dict[str, str], Any]:
    """Return (headers, cookies) for calling a tool server per its ``auth_type``."""
    auth_type = connection.get("auth_type", "bearer")
    cookies = {}
    headers = {}
    if auth_type == "bearer":
        headers["Authorization"] = f"Bearer {connection.get('key', '')}"
    elif auth_type == "session":
        # Forward the calling user's own cookies and token.
        cookies = request.cookies
        headers["Authorization"] = f"Bearer {request.state.token.credentials}"
    elif auth_type == "system_oauth":
        cookies = request.cookies
        oauth_token = extra_params.get("__oauth_token__", None)
        if oauth_token:
            headers["Authorization"] = (
                f"Bearer {oauth_token.get('access_token', '')}"
            )
    # auth_type == "none" (or anything else): no credentials are attached.
    headers["Content-Type"] = "application/json"
    return headers, cookies


def _make_tool_server_function(
    tool_server_data: dict, function_name: str, headers: dict, cookies: Any
) -> Callable[..., Awaitable]:
    """Create a coroutine function that calls ``function_name`` on a tool server.

    Everything the call needs is passed in as an argument of this factory, so
    each returned closure keeps its own server, operation, headers and cookies
    (no late binding to variables that later loop iterations reassign).
    """

    async def tool_function(**kwargs):
        return await execute_tool_server(
            url=tool_server_data["url"],
            headers=headers,
            cookies=cookies,
            name=function_name,
            params=kwargs,
            server_data=tool_server_data,
        )

    return tool_function


def _register_tool(
    tools_dict: dict, function_name: str, tool_dict: dict, prefix: str
) -> None:
    """Store ``tool_dict`` under ``function_name``, prepending ``<prefix>_`` until unique."""
    while function_name in tools_dict:
        log.warning(f"Tool {function_name} already exists in another tools!")
        function_name = f"{prefix}_{function_name}"
    tools_dict[function_name] = tool_dict


async def _add_tool_server_tools(
    request: Request, tool_id: str, extra_params: dict, tools_dict: dict
) -> None:
    """Register one tool per operation of the OpenAPI tool server named by ``tool_id``."""
    parsed = _parse_tool_server_id(tool_id)
    if parsed is None:
        log.warning(f"Invalid tool server id: {tool_id}")
        return
    server_type, server_id, _function_names = parsed
    # NOTE(review): the "|fn1,fn2" suffix is parsed but, as before, not used to
    # restrict which functions get registered — confirm intent.
    if server_type != "openapi":
        return

    tool_server_data = None
    for server in await get_tool_servers(request):
        if server["id"] == server_id:
            tool_server_data = server
            break
    if tool_server_data is None:
        log.warning(f"Tool server data not found for {server_id}")
        return

    connection = request.app.state.config.TOOL_SERVER_CONNECTIONS[
        tool_server_data.get("idx", 0)
    ]
    specs = tool_server_data.get("specs", [])
    if not specs:
        return
    # Auth depends only on the connection, so compute it once per server.
    headers, cookies = _build_tool_server_auth(request, connection, extra_params)

    for spec in specs:
        function_name = spec["name"]
        tool_function = _make_tool_server_function(
            tool_server_data, function_name, dict(headers), cookies
        )
        _register_tool(
            tools_dict,
            function_name,
            {
                "tool_id": tool_id,
                "callable": get_async_tool_function_and_apply_extra_params(
                    tool_function, {}
                ),
                "spec": spec,
                # Misc info
                "type": "external",
            },
            prefix=server_id,
        )


def _add_module_tools(
    request: Request,
    tool: Any,
    tool_id: str,
    user: UserModel,
    extra_params: dict,
    tools_dict: dict,
) -> None:
    """Register one tool per spec of a locally installed (Python module) tool."""
    module = request.app.state.TOOLS.get(tool_id, None)
    if module is None:
        module, _ = load_tool_module_by_id(tool_id)
        request.app.state.TOOLS[tool_id] = module

    # Per-tool copy. The partials below keep references to these values, so a
    # shared dict mutated for a later tool would leak into earlier ones.
    tool_extra_params = {**extra_params, "__id__": tool_id}

    # Set valves for the tool
    if hasattr(module, "valves") and hasattr(module, "Valves"):
        valves = Tools.get_tool_valves_by_id(tool_id) or {}
        module.valves = module.Valves(**valves)

    if hasattr(module, "UserValves"):
        user_valves = module.UserValves(  # type: ignore
            **(Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) or {})
        )
        # New __user__ dict per tool instead of mutating the caller's one.
        tool_extra_params["__user__"] = {
            **tool_extra_params["__user__"],
            "valves": user_valves,
        }

    for spec in tool.specs:
        parameters = spec.get("parameters", {})
        properties = parameters.get("properties", {})

        # TODO: Fix hack for OpenAI API
        # Some times breaks OpenAI but others don't. Leaving the comment
        for val in properties.values():
            if val.get("type") == "str":
                val["type"] = "string"

        # Remove internal reserved parameters (e.g. __id__, __user__)
        if "properties" in parameters:
            parameters["properties"] = {
                key: val for key, val in properties.items() if not key.startswith("__")
            }

        # convert to function that takes only model params and inserts custom params
        function_name = spec["name"]
        tool_function = getattr(module, function_name)
        tool_callable = get_async_tool_function_and_apply_extra_params(
            tool_function, tool_extra_params
        )

        # TODO: Support Pydantic models as parameters
        doc = tool_callable.__doc__
        if doc and doc.strip() != "":
            # Keep only the text before the first :param / :return field.
            spec["description"] = re.split(":(param|return)", doc, maxsplit=1)[0]
        else:
            spec["description"] = function_name

        _register_tool(
            tools_dict,
            function_name,
            {
                "tool_id": tool_id,
                "callable": tool_callable,
                "spec": spec,
                # Misc info
                "metadata": {
                    "file_handler": hasattr(module, "file_handler")
                    and module.file_handler,
                    "citation": hasattr(module, "citation") and module.citation,
                },
            },
            prefix=tool_id,
        )


async def get_tools(
    request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
    """
    Resolve ``tool_ids`` into a mapping of function name -> tool entry.

    * Ids found by ``Tools.get_tool_by_id`` are local tool modules. The module
      is loaded (and cached in ``request.app.state.TOOLS``), its ``Valves`` and
      the user's ``UserValves`` are applied, and one entry is built per spec.
    * Unknown ids starting with ``server:`` name OpenAPI tool servers; one
      entry is built per operation of the matching server.
    * Any other unknown id is skipped.

    Each entry holds ``tool_id``, ``callable`` (an async function taking only
    the model-visible arguments) and ``spec``. Local tools also carry
    ``metadata``; tool-server tools carry ``"type": "external"``. A function
    name that is already taken is prefixed with ``<tool_id>_`` (local) or
    ``<server_id>_`` (server) until unique.

    ``extra_params`` itself is not modified; each local tool binds its own copy
    (with ``__id__`` set and, when applicable, its own ``__user__["valves"]``).

    Args:
        request: The current request; app state supplies caches and config.
        tool_ids: Tool ids to resolve.
        user: The user whose ``UserValves`` are loaded.
        extra_params: Reserved arguments (``__user__``, ``__oauth_token__``, ...)
            bound into tool functions that declare them.

    Returns:
        dict[str, dict]: Tool entries keyed by (possibly prefixed) function name.
    """
    tools_dict = {}
    for tool_id in tool_ids:
        tool = Tools.get_tool_by_id(tool_id)
        if tool is None:
            if tool_id.startswith("server:"):
                await _add_tool_server_tools(request, tool_id, extra_params, tools_dict)
            continue
        _add_module_tools(request, tool, tool_id, user, extra_params, tools_dict)
    return tools_dict
- def parse_description(docstring: str | None) -> str:
- """
- Parse a function's docstring to extract the description.
- Args:
- docstring (str): The docstring to parse.
- Returns:
- str: The description.
- """
- if not docstring:
- return ""
- lines = [line.strip() for line in docstring.strip().split("\n")]
- description_lines: list[str] = []
- for line in lines:
- if re.match(r":param", line) or re.match(r":return", line):
- break
- description_lines.append(line)
- return "\n".join(description_lines)
def parse_docstring(docstring):
    """
    Map parameter names to descriptions taken from ``:param name: text`` lines.

    Lines are stripped before matching. Parameters whose names start with
    ``__`` (reserved, injected arguments) are omitted. If a name appears more
    than once, the last description wins.

    Args:
        docstring (str): The docstring to parse.

    Returns:
        dict: Parameter name -> description (empty when there is no docstring).
    """
    if not docstring:
        return {}

    # Regex to match `:param name: description` format
    param_pattern = re.compile(r":param (\w+):\s*(.+)")
    matches = (param_pattern.match(line.strip()) for line in docstring.splitlines())
    return {
        match.group(1): match.group(2)
        for match in matches
        if match is not None and not match.group(1).startswith("__")
    }
def convert_function_to_pydantic_model(func: Callable) -> type[BaseModel]:
    """
    Convert a Python function's signature and docstring into a Pydantic model.

    Each parameter becomes a field typed by its resolved type hint (``Any``
    when unannotated). A parameter without a default is a required field. A
    ``:param name: ...`` line in the docstring becomes the field description.
    Parameters whose names start with ``__`` (reserved values injected at call
    time, such as ``__user__``) are not exposed as fields; ``parse_docstring``
    ignores them in the same way. The model is named after the function, and
    its ``__doc__`` is the docstring's leading description.

    Args:
        func: The function whose type hints and docstring should be converted.

    Returns:
        A Pydantic model class.
    """
    type_hints = get_type_hints(func)
    docstring = func.__doc__
    param_descriptions = parse_docstring(docstring)

    field_defs = {}
    for name, param in inspect.signature(func).parameters.items():
        if name.startswith("__"):
            # Injected by the runtime, never supplied by the model.
            continue
        type_hint = type_hints.get(name, Any)
        default_value = param.default if param.default is not param.empty else ...
        description = param_descriptions.get(name)
        if description:
            field_defs[name] = (
                type_hint,
                Field(default_value, description=description),
            )
        else:
            field_defs[name] = (type_hint, default_value)

    model = create_model(func.__name__, **field_defs)
    model.__doc__ = parse_description(docstring)
    return model
def get_functions_from_tool(tool: object) -> list[Callable]:
    """
    Return the tool functions exposed by a tool module or object.

    Attributes are visited in ``dir()`` order, which is sorted by name. An
    attribute is kept when all of the following hold:

    * its name does not start with ``__``; single-underscore names are kept;
    * it is callable;
    * it is not a class, so nested classes such as ``Valves`` are excluded.

    Each attribute is fetched once, so both checks and the returned list see
    the same object. Dunder names are rejected before any ``getattr``.
    """
    functions = []
    for attr_name in dir(tool):
        if attr_name.startswith("__"):
            continue
        attr = getattr(tool, attr_name)
        if callable(attr) and not inspect.isclass(attr):
            functions.append(attr)
    return functions
def get_tool_specs(tool_module: object) -> list[dict]:
    """
    Build a function spec for every tool function of ``tool_module``.

    Each function returned by ``get_functions_from_tool`` is converted to a
    Pydantic model with ``convert_function_to_pydantic_model``. The model is
    then turned into a spec dict by LangChain's ``convert_to_openai_function``,
    which is imported here as ``convert_pydantic_model_to_openai_function_spec``.
    """
    return [
        convert_pydantic_model_to_openai_function_spec(
            convert_function_to_pydantic_model(tool_function)
        )
        for tool_function in get_functions_from_tool(tool_module)
    ]
def resolve_schema(schema, components):
    """
    Recursively resolve a JSON schema against OpenAPI ``components``.

    Local ``$ref`` pointers of the form ``#/components/<section>/<name>`` are
    replaced by the referenced schema. Nested schemas are resolved too: those
    under ``properties``, ``items`` and ``additionalProperties``, and the
    members of ``allOf``/``anyOf``/``oneOf``.

    A ``$ref`` that points back to a schema already being expanded on the
    current path (a recursive type) is replaced by ``{}`` instead of recursing
    forever. Unresolvable references and falsy schemas also yield ``{}``. The
    input is never mutated; new (deep-copied) objects are returned.

    Args:
        schema: The schema to resolve.
        components: The OpenAPI ``components`` object.

    Returns:
        dict: The resolved schema.
    """

    def _lookup(ref_path):
        # "#/components/schemas/Pet" -> ["schemas", "Pet"]: the leading
        # "components" segment is skipped because `components` is already it.
        target = components
        for part in ref_path.strip("#/").split("/")[1:]:
            if not isinstance(target, dict):
                return {}
            # Undo JSON Pointer escaping (RFC 6901).
            target = target.get(part.replace("~1", "/").replace("~0", "~"), {})
        return target

    def _resolve(node, active_refs):
        if not node:
            return {}
        if not isinstance(node, dict):
            # e.g. a list-valued "items"; nothing to resolve inside it.
            return copy.deepcopy(node)
        if "$ref" in node:
            ref_path = node["$ref"]
            if ref_path in active_refs:
                # Recursive reference: stop expanding here.
                return {}
            return _resolve(_lookup(ref_path), active_refs | {ref_path})

        resolved = copy.deepcopy(node)

        properties = resolved.get("properties")
        if isinstance(properties, dict):
            for prop_name, prop_schema in properties.items():
                properties[prop_name] = _resolve(prop_schema, active_refs)

        if "items" in resolved:
            resolved["items"] = _resolve(resolved["items"], active_refs)

        if isinstance(resolved.get("additionalProperties"), dict):
            resolved["additionalProperties"] = _resolve(
                resolved["additionalProperties"], active_refs
            )

        for combinator in ("allOf", "anyOf", "oneOf"):
            if isinstance(resolved.get(combinator), list):
                resolved[combinator] = [
                    _resolve(sub_schema, active_refs)
                    for sub_schema in resolved[combinator]
                ]

        return resolved

    return _resolve(schema, frozenset())
def convert_openapi_to_tool_payload(openapi_spec):
    """
    Convert an OpenAPI specification into a list of tool payloads.

    One payload is produced per operation that has an ``operationId``. Its
    ``parameters`` object combines two sources:

    * the operation's parameters, each described by its type and description,
      with enum values appended to the description and ``items`` kept for
      arrays;
    * the properties of an ``application/json`` request body, resolved
      through ``resolve_schema``.

    Parameter objects, parameter schemas and request bodies given as
    ``$ref`` are followed. Entries of a path item that are not operations,
    such as a shared ``parameters`` list or a ``summary`` string, are skipped.

    Args:
        openapi_spec (dict): The OpenAPI specification as a Python dict.

    Returns:
        list: A list of tool payloads.
    """
    components = openapi_spec.get("components", {})

    def _deref(obj):
        # Follow a top-level $ref through resolve_schema; return anything else unchanged.
        if isinstance(obj, dict) and "$ref" in obj:
            return resolve_schema(obj, components)
        return obj

    tool_payload = []
    for path_item in openapi_spec.get("paths", {}).values():
        if not isinstance(path_item, dict):
            continue
        for operation in path_item.values():
            # A path item may also hold non-operation entries (e.g. a
            # path-level "parameters" list or a "summary" string).
            if not isinstance(operation, dict) or not operation.get("operationId"):
                continue

            tool = {
                "name": operation.get("operationId"),
                "description": operation.get(
                    "description",
                    operation.get("summary", "No description available."),
                ),
                "parameters": {"type": "object", "properties": {}, "required": []},
            }
            properties = tool["parameters"]["properties"]
            required = tool["parameters"]["required"]

            # Extract path and query parameters
            for param in operation.get("parameters", []):
                param = _deref(param)
                param_name = param.get("name") if isinstance(param, dict) else None
                if not param_name:
                    continue
                param_schema = _deref(param.get("schema") or {})

                description = param_schema.get("description", "")
                if not description:
                    description = param.get("description") or ""
                enum_values = param_schema.get("enum")
                if enum_values and isinstance(enum_values, list):
                    # str() so that non-string enum members (e.g. ints) can be joined.
                    description += (
                        f". Possible values: {', '.join(str(v) for v in enum_values)}"
                    )

                param_property = {
                    "type": param_schema.get("type"),
                    "description": description,
                }
                # Include items property for array types (required by OpenAI)
                if param_schema.get("type") == "array" and "items" in param_schema:
                    param_property["items"] = _deref(param_schema["items"])

                properties[param_name] = param_property
                if param.get("required"):
                    required.append(param_name)

            # Extract and resolve requestBody if available
            request_body = _deref(operation.get("requestBody"))
            if request_body:
                content = request_body.get("content", {})
                json_schema = content.get("application/json", {}).get("schema")
                if json_schema:
                    resolved_schema = resolve_schema(json_schema, components)
                    if resolved_schema.get("properties"):
                        properties.update(resolved_schema["properties"])
                        if "required" in resolved_schema:
                            # dict.fromkeys de-duplicates while preserving order,
                            # unlike set(), so the output is deterministic.
                            tool["parameters"]["required"] = list(
                                dict.fromkeys(required + resolved_schema["required"])
                            )
                    elif resolved_schema.get("type") == "array":
                        tool["parameters"] = resolved_schema  # special case for array

            tool_payload.append(tool)

    return tool_payload
async def set_tool_servers(request: Request):
    """
    Rebuild the tool server list from the configured connections and cache it.

    The list is built by ``get_tool_servers_data`` from
    ``config.TOOL_SERVER_CONNECTIONS``. It is stored on
    ``request.app.state.TOOL_SERVERS`` and, when Redis is configured, also
    written as JSON under the ``tool_servers`` key.

    Returns:
        The freshly built tool server list.
    """
    app_state = request.app.state
    servers = await get_tool_servers_data(app_state.config.TOOL_SERVER_CONNECTIONS)
    app_state.TOOL_SERVERS = servers

    if app_state.redis is not None:
        await app_state.redis.set("tool_servers", json.dumps(app_state.TOOL_SERVERS))
    return app_state.TOOL_SERVERS
async def get_tool_servers(request: Request):
    """
    Return the tool server list, rebuilding it when no cached list is usable.

    When Redis is configured, the JSON list stored under ``tool_servers`` is
    read and mirrored into ``request.app.state.TOOL_SERVERS``. The list is
    rebuilt with ``set_tool_servers`` when nothing usable was read: Redis is
    not configured, the key is missing, the value is unreadable, or the
    stored list is empty.

    Returns:
        The list of tool server entries.
    """
    tool_servers = []
    redis = request.app.state.redis
    if redis is not None:
        try:
            cached = await redis.get("tool_servers")
            # A missing key yields None; only decode when something is stored
            # so that a cold cache is not reported as an error.
            if cached:
                tool_servers = json.loads(cached)
                request.app.state.TOOL_SERVERS = tool_servers
        except Exception as e:
            log.error(f"Error fetching tool_servers from Redis: {e}")

    if not tool_servers:
        tool_servers = await set_tool_servers(request)
    return tool_servers
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
    """
    Fetch and parse a tool server's OpenAPI document.

    URLs ending in ``.yaml`` or ``.yml`` are parsed as YAML. Any other body is
    parsed as JSON, falling back to YAML when it is not valid JSON.

    Args:
        token: Optional bearer token, sent as ``Authorization`` when truthy.
        url: Full URL of the OpenAPI document.

    Returns:
        The parsed document.

    Raises:
        Exception: On any transport, HTTP (non-200) or parse failure. If a
            non-200 response has a JSON body with a ``detail`` field, that
            detail becomes the message. Otherwise the message is ``str()`` of
            the underlying error.
    """
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    if token:
        headers["Authorization"] = f"Bearer {token}"

    try:
        timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(
                url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
            ) as response:
                if response.status != 200:
                    try:
                        error_body = await response.json()
                    except Exception:
                        # Non-JSON error page: keep the status and the raw body.
                        body_text = await response.text()
                        error_body = f"HTTP {response.status}: {body_text}"
                    raise Exception(error_body)
                text_content = await response.text()

        # Check if URL ends with .yaml or .yml to determine format
        if url.lower().endswith((".yaml", ".yml")):
            res = yaml.safe_load(text_content)
        else:
            try:
                res = json.loads(text_content)
            except json.JSONDecodeError:
                res = yaml.safe_load(text_content)
    except Exception as err:
        log.exception(f"Could not fetch tool server spec from {url}")
        # A non-200 JSON body is raised above as Exception(body); surface its "detail".
        payload = err.args[0] if err.args else None
        if isinstance(payload, dict) and "detail" in payload:
            error = payload["detail"]
        else:
            error = str(err)
        raise Exception(error) from err

    log.debug(f"Fetched data: {res}")
    return res
async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Load the OpenAPI specs of all enabled ``openapi`` tool server connections.

    For each enabled connection, the spec is obtained in one of two ways:

    * ``spec_type == "url"`` (the default): fetched from its URL with
      ``get_tool_server_data``;
    * ``spec_type == "json"``: parsed from the inline ``spec`` string.

    Fetches run concurrently. Connections whose spec cannot be loaded, or
    whose spec is not a JSON object, are logged and left out.

    Args:
        servers: The configured tool server connections, in config order.

    Returns:
        One dict per loaded server, with these keys:

        * ``id``: ``info.id``, or the connection index as a string;
        * ``idx``: index into ``servers``;
        * ``url``;
        * ``openapi``: the spec, with ``info.title``/``info.description``
          overridden by the connection's ``info.name``/``info.description``;
        * ``info``: that same, already overridden ``info`` object;
        * ``specs``: tool payloads from ``convert_openapi_to_tool_payload``.
    """
    tasks = []
    server_entries = []
    for idx, server in enumerate(servers):
        # `or {}` also guards against explicit nulls in stored config.
        config = server.get("config") or {}
        if not config.get("enable") or server.get("type", "openapi") != "openapi":
            continue

        info = server.get("info") or {}
        # Only "bearer" connections send a token when fetching the spec.
        token = None
        if server.get("auth_type", "bearer") == "bearer":
            token = server.get("key", "")

        server_id = info.get("id") or str(idx)
        server_url = server.get("url")
        spec_type = server.get("spec_type", "url")

        task = None
        if spec_type == "url":
            # Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL
            spec_url = get_tool_server_url(server_url, server.get("path", "openapi.json"))
            task = get_tool_server_data(token, spec_url)
        elif spec_type == "json" and server.get("spec", ""):
            spec_json = None
            try:
                spec_json = json.loads(server.get("spec", ""))
            except Exception as e:
                log.error(f"Error parsing JSON spec for tool server {server_id}: {e}")
            if spec_json:
                # Wrap the already-parsed spec so it can be gathered with the fetches.
                task = asyncio.sleep(0, result=spec_json)

        if task:
            tasks.append(task)
            server_entries.append((server_id, idx, server, server_url, info))

    # Execute tasks concurrently
    responses = await asyncio.gather(*tasks, return_exceptions=True)

    results = []
    for (server_id, idx, server, url, info), response in zip(server_entries, responses):
        if isinstance(response, BaseException):
            log.error(f"Failed to connect to {url} OpenAPI tool server: {response}")
            continue
        if not isinstance(response, dict):
            log.error(f"Failed to load {url} OpenAPI tool server: spec is not an object")
            continue

        specs = convert_openapi_to_tool_payload(response)

        if info:
            # The connection's configured name/description override the spec's.
            spec_info = response.get("info")
            if not isinstance(spec_info, dict):
                spec_info = {}
            response["info"] = spec_info
            if "name" in info:
                spec_info["title"] = info.get("name", "Tool Server")
            if "description" in info:
                spec_info["description"] = info.get("description", "")

        results.append(
            {
                "id": str(server_id),
                "idx": idx,
                "url": server.get("url"),
                "openapi": response,
                "info": response.get("info", {}),
                "specs": specs,
            }
        )

    return results
async def execute_tool_server(
    url: str,
    headers: Dict[str, str],
    cookies: Dict[str, str],
    name: str,
    params: Dict[str, Any],
    server_data: Dict[str, Any],
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
    """
    Invoke the OpenAPI operation ``name`` on a tool server.

    The operation is looked up by ``operationId`` in ``server_data["openapi"]``.
    Arguments from ``params`` are routed as follows:

    * declared ``path`` parameters are substituted into the route,
      percent-encoded;
    * declared ``query`` parameters are URL-encoded into the query string,
      with lists repeated as ``k=a&k=b``;
    * when the operation declares request-body content, all ``params`` are
      sent as the JSON body of POST/PUT/PATCH requests.

    Redirects are not followed.

    Args:
        url: Base URL of the tool server.
        headers: Request headers (auth, content type).
        cookies: Cookies to send.
        name: The ``operationId`` to call.
        params: Arguments supplied for the operation.
        server_data: Tool server entry holding the parsed ``openapi`` spec.

    Returns:
        ``(response_data, response_headers)`` on success. ``response_data`` is
        the decoded JSON body, or the raw text when decoding fails. On any
        error (unknown operation, HTTP status >= 400, transport failure)
        returns ``({"error": message}, None)`` instead of raising.
    """
    from urllib.parse import quote, urlencode

    try:
        operation_match = None
        for route_path, path_item in server_data.get("openapi", {}).get("paths", {}).items():
            if not isinstance(path_item, dict):
                continue
            for http_method, operation in path_item.items():
                # Path items can also contain non-operation entries (e.g. a
                # path-level "parameters" list), so only inspect dicts.
                if isinstance(operation, dict) and operation.get("operationId") == name:
                    operation_match = (route_path, http_method.lower(), operation)
                    break
            if operation_match:
                break

        if operation_match is None:
            raise Exception(f"No matching route found for operationId: {name}")
        route_path, http_method, operation = operation_match

        path_params = {}
        query_params = {}
        for param in operation.get("parameters", []):
            param_name = param.get("name")
            if param_name not in params:
                continue
            if param.get("in") == "path":
                path_params[param_name] = params[param_name]
            elif param.get("in") == "query":
                query_params[param_name] = params[param_name]

        # Avoid "//" when the configured base URL ends with a slash.
        final_url = f"{url.rstrip('/')}{route_path}"
        for key, value in path_params.items():
            # Percent-encode so values containing "/", "?" or "#" stay inside
            # their own path segment.
            final_url = final_url.replace(f"{{{key}}}", quote(str(value), safe=""))
        if query_params:
            # urlencode escapes reserved characters; doseq expands list values.
            final_url = f"{final_url}?{urlencode(query_params, doseq=True)}"

        body_params = {}
        if operation.get("requestBody", {}).get("content") and params:
            body_params = params

        request_kwargs = {
            "headers": headers,
            "cookies": cookies,
            "ssl": AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
            "allow_redirects": False,
        }
        if http_method in ("post", "put", "patch"):
            request_kwargs["json"] = body_params

        async with aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        ) as session:
            request_method = getattr(session, http_method)
            async with request_method(final_url, **request_kwargs) as response:
                if response.status >= 400:
                    text = await response.text()
                    raise Exception(f"HTTP error {response.status}: {text}")
                try:
                    response_data = await response.json()
                except Exception:
                    response_data = await response.text()
                return (response_data, response.headers)

    except Exception as err:
        error = str(err)
        log.exception(f"API Request Error: {error}")
        return ({"error": error}, None)
def get_tool_server_url(url: Optional[str], path: str) -> str:
    """
    Build the full URL for a tool server, given a base url and a path.

    A ``path`` containing ``"://"`` is already absolute and is returned as-is.
    Otherwise it is appended to ``url`` with a single ``/`` separator. A
    trailing slash on the base therefore does not produce ``//``, which some
    servers do not route. A ``None`` base is treated as empty.

    Examples:
        >>> get_tool_server_url("http://host/", "openapi.json")
        'http://host/openapi.json'
    """
    if "://" in path:
        # If it contains "://", it's a full URL
        return path
    if not path.startswith("/"):
        # Ensure the path starts with a slash
        path = f"/{path}"
    base = (url or "").rstrip("/")
    return f"{base}{path}"
|