tools.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. import inspect
  2. import logging
  3. import re
  4. import inspect
  5. import aiohttp
  6. import asyncio
  7. from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union
  8. from functools import update_wrapper, partial
  9. from fastapi import Request
  10. from pydantic import BaseModel, Field, create_model
  11. from langchain_core.utils.function_calling import convert_to_openai_function
  12. from open_webui.models.tools import Tools
  13. from open_webui.models.users import UserModel
  14. from open_webui.utils.plugin import load_tools_module_by_id
  15. log = logging.getLogger(__name__)
  16. def apply_extra_params_to_tool_function(
  17. function: Callable, extra_params: dict
  18. ) -> Callable[..., Awaitable]:
  19. sig = inspect.signature(function)
  20. extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters}
  21. partial_func = partial(function, **extra_params)
  22. if inspect.iscoroutinefunction(function):
  23. update_wrapper(partial_func, function)
  24. return partial_func
  25. async def new_function(*args, **kwargs):
  26. return partial_func(*args, **kwargs)
  27. update_wrapper(new_function, function)
  28. return new_function
  29. # Mutation on extra_params
  30. def get_tools(
  31. request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
  32. ) -> dict[str, dict]:
  33. tools_dict = {}
  34. for tool_id in tool_ids:
  35. tools = Tools.get_tool_by_id(tool_id)
  36. if tools is None:
  37. continue
  38. module = request.app.state.TOOLS.get(tool_id, None)
  39. if module is None:
  40. module, _ = load_tools_module_by_id(tool_id)
  41. request.app.state.TOOLS[tool_id] = module
  42. extra_params["__id__"] = tool_id
  43. if hasattr(module, "valves") and hasattr(module, "Valves"):
  44. valves = Tools.get_tool_valves_by_id(tool_id) or {}
  45. module.valves = module.Valves(**valves)
  46. if hasattr(module, "UserValves"):
  47. extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
  48. **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
  49. )
  50. for spec in tools.specs:
  51. # TODO: Fix hack for OpenAI API
  52. # Some times breaks OpenAI but others don't. Leaving the comment
  53. for val in spec.get("parameters", {}).get("properties", {}).values():
  54. if val["type"] == "str":
  55. val["type"] = "string"
  56. # Remove internal parameters
  57. spec["parameters"]["properties"] = {
  58. key: val
  59. for key, val in spec["parameters"]["properties"].items()
  60. if not key.startswith("__")
  61. }
  62. function_name = spec["name"]
  63. # convert to function that takes only model params and inserts custom params
  64. original_func = getattr(module, function_name)
  65. callable = apply_extra_params_to_tool_function(original_func, extra_params)
  66. if callable.__doc__ and callable.__doc__.strip() != "":
  67. s = re.split(":(param|return)", callable.__doc__, 1)
  68. spec["description"] = s[0]
  69. else:
  70. spec["description"] = function_name
  71. # TODO: This needs to be a pydantic model
  72. tool_dict = {
  73. "spec": spec,
  74. "callable": callable,
  75. "toolkit_id": tool_id,
  76. "pydantic_model": function_to_pydantic_model(callable),
  77. # Misc info
  78. "file_handler": hasattr(module, "file_handler") and module.file_handler,
  79. "citation": hasattr(module, "citation") and module.citation,
  80. }
  81. # TODO: if collision, prepend toolkit name
  82. if function_name in tools_dict:
  83. log.warning(f"Tool {function_name} already exists in another tools!")
  84. log.warning(f"Collision between {tools} and {tool_id}.")
  85. log.warning(f"Discarding {tools}.{function_name}")
  86. else:
  87. tools_dict[function_name] = tool_dict
  88. return tools_dict
  89. def parse_description(docstring: str | None) -> str:
  90. """
  91. Parse a function's docstring to extract the description.
  92. Args:
  93. docstring (str): The docstring to parse.
  94. Returns:
  95. str: The description.
  96. """
  97. if not docstring:
  98. return ""
  99. lines = [line.strip() for line in docstring.strip().split("\n")]
  100. description_lines: list[str] = []
  101. for line in lines:
  102. if re.match(r":param", line) or re.match(r":return", line):
  103. break
  104. description_lines.append(line)
  105. return "\n".join(description_lines)
  106. def parse_docstring(docstring):
  107. """
  108. Parse a function's docstring to extract parameter descriptions in reST format.
  109. Args:
  110. docstring (str): The docstring to parse.
  111. Returns:
  112. dict: A dictionary where keys are parameter names and values are descriptions.
  113. """
  114. if not docstring:
  115. return {}
  116. # Regex to match `:param name: description` format
  117. param_pattern = re.compile(r":param (\w+):\s*(.+)")
  118. param_descriptions = {}
  119. for line in docstring.splitlines():
  120. match = param_pattern.match(line.strip())
  121. if not match:
  122. continue
  123. param_name, param_description = match.groups()
  124. if param_name.startswith("__"):
  125. continue
  126. param_descriptions[param_name] = param_description
  127. return param_descriptions
  128. def function_to_pydantic_model(func: Callable) -> type[BaseModel]:
  129. """
  130. Converts a Python function's type hints and docstring to a Pydantic model,
  131. including support for nested types, default values, and descriptions.
  132. Args:
  133. func: The function whose type hints and docstring should be converted.
  134. model_name: The name of the generated Pydantic model.
  135. Returns:
  136. A Pydantic model class.
  137. """
  138. type_hints = get_type_hints(func)
  139. signature = inspect.signature(func)
  140. parameters = signature.parameters
  141. docstring = func.__doc__
  142. descriptions = parse_docstring(docstring)
  143. tool_description = parse_description(docstring)
  144. field_defs = {}
  145. for name, param in parameters.items():
  146. type_hint = type_hints.get(name, Any)
  147. default_value = param.default if param.default is not param.empty else ...
  148. description = descriptions.get(name, None)
  149. if not description:
  150. field_defs[name] = type_hint, default_value
  151. continue
  152. field_defs[name] = type_hint, Field(default_value, description=description)
  153. model = create_model(func.__name__, **field_defs)
  154. model.__doc__ = tool_description
  155. return model
  156. def get_callable_attributes(tool: object) -> list[Callable]:
  157. return [
  158. getattr(tool, func)
  159. for func in dir(tool)
  160. if callable(getattr(tool, func))
  161. and not func.startswith("__")
  162. and not inspect.isclass(getattr(tool, func))
  163. ]
  164. def get_tools_specs(tool_class: object) -> list[dict]:
  165. function_list = get_callable_attributes(tool_class)
  166. models = map(function_to_pydantic_model, function_list)
  167. return [convert_to_openai_function(tool) for tool in models]
  168. import copy
  169. def resolve_schema(schema, components):
  170. """
  171. Recursively resolves a JSON schema using OpenAPI components.
  172. """
  173. if not schema:
  174. return {}
  175. if "$ref" in schema:
  176. ref_path = schema["$ref"]
  177. ref_parts = ref_path.strip("#/").split("/")
  178. resolved = components
  179. for part in ref_parts[1:]: # Skip the initial 'components'
  180. resolved = resolved.get(part, {})
  181. return resolve_schema(resolved, components)
  182. resolved_schema = copy.deepcopy(schema)
  183. # Recursively resolve inner schemas
  184. if "properties" in resolved_schema:
  185. for prop, prop_schema in resolved_schema["properties"].items():
  186. resolved_schema["properties"][prop] = resolve_schema(
  187. prop_schema, components
  188. )
  189. if "items" in resolved_schema:
  190. resolved_schema["items"] = resolve_schema(resolved_schema["items"], components)
  191. return resolved_schema
  192. def convert_openapi_to_tool_payload(openapi_spec):
  193. """
  194. Converts an OpenAPI specification into a custom tool payload structure.
  195. Args:
  196. openapi_spec (dict): The OpenAPI specification as a Python dict.
  197. Returns:
  198. list: A list of tool payloads.
  199. """
  200. tool_payload = []
  201. for path, methods in openapi_spec.get("paths", {}).items():
  202. for method, operation in methods.items():
  203. tool = {
  204. "type": "function",
  205. "name": operation.get("operationId"),
  206. "description": operation.get("summary", "No description available."),
  207. "parameters": {"type": "object", "properties": {}, "required": []},
  208. }
  209. # Extract path and query parameters
  210. for param in operation.get("parameters", []):
  211. param_name = param["name"]
  212. param_schema = param.get("schema", {})
  213. tool["parameters"]["properties"][param_name] = {
  214. "type": param_schema.get("type"),
  215. "description": param_schema.get("description", ""),
  216. }
  217. if param.get("required"):
  218. tool["parameters"]["required"].append(param_name)
  219. # Extract and resolve requestBody if available
  220. request_body = operation.get("requestBody")
  221. if request_body:
  222. content = request_body.get("content", {})
  223. json_schema = content.get("application/json", {}).get("schema")
  224. if json_schema:
  225. resolved_schema = resolve_schema(
  226. json_schema, openapi_spec.get("components", {})
  227. )
  228. if resolved_schema.get("properties"):
  229. tool["parameters"]["properties"].update(
  230. resolved_schema["properties"]
  231. )
  232. if "required" in resolved_schema:
  233. tool["parameters"]["required"] = list(
  234. set(
  235. tool["parameters"]["required"]
  236. + resolved_schema["required"]
  237. )
  238. )
  239. elif resolved_schema.get("type") == "array":
  240. tool["parameters"] = resolved_schema # special case for array
  241. tool_payload.append(tool)
  242. return tool_payload
  243. async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
  244. headers = {
  245. "Accept": "application/json",
  246. "Content-Type": "application/json",
  247. }
  248. if token:
  249. headers["Authorization"] = f"Bearer {token}"
  250. error = None
  251. try:
  252. async with aiohttp.ClientSession() as session:
  253. async with session.get(url, headers=headers) as response:
  254. if response.status != 200:
  255. error_body = await response.json()
  256. raise Exception(error_body)
  257. res = await response.json()
  258. except Exception as err:
  259. print("Error:", err)
  260. if isinstance(err, dict) and "detail" in err:
  261. error = err["detail"]
  262. else:
  263. error = str(err)
  264. raise Exception(error)
  265. data = {
  266. "openapi": res,
  267. "info": res.get("info", {}),
  268. "specs": convert_openapi_to_tool_payload(res),
  269. }
  270. print("Fetched data:", data)
  271. return data
  272. async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
  273. enabled_servers = [
  274. server for server in servers if server.get("config", {}).get("enable")
  275. ]
  276. urls = [
  277. (
  278. server,
  279. f"{server.get('url')}/{server.get('path', 'openapi.json')}",
  280. server.get("key", ""),
  281. )
  282. for server in enabled_servers
  283. ]
  284. tasks = [get_tool_server_data(token, url) for _, url, token in urls]
  285. results: List[Dict[str, Any]] = []
  286. responses = await asyncio.gather(*tasks, return_exceptions=True)
  287. for (server, _, _), response in zip(urls, responses):
  288. if isinstance(response, Exception):
  289. url_path = server.get("path", "openapi.json")
  290. full_url = f"{server.get('url')}/{url_path}"
  291. print(f"Failed to connect to {full_url} OpenAPI tool server")
  292. else:
  293. results.append(
  294. {
  295. "url": server.get("url"),
  296. "openapi": response["openapi"],
  297. "info": response["info"],
  298. "specs": response["specs"],
  299. }
  300. )
  301. return results
  302. async def execute_tool_server(
  303. token: str, url: str, name: str, params: Dict[str, Any], server_data: Dict[str, Any]
  304. ) -> Any:
  305. error = None
  306. try:
  307. openapi = server_data.get("openapi", {})
  308. paths = openapi.get("paths", {})
  309. matching_route = None
  310. for route_path, methods in paths.items():
  311. for http_method, operation in methods.items():
  312. if isinstance(operation, dict) and operation.get("operationId") == name:
  313. matching_route = (route_path, methods)
  314. break
  315. if matching_route:
  316. break
  317. if not matching_route:
  318. raise Exception(f"No matching route found for operationId: {name}")
  319. route_path, methods = matching_route
  320. method_entry = None
  321. for http_method, operation in methods.items():
  322. if operation.get("operationId") == name:
  323. method_entry = (http_method.lower(), operation)
  324. break
  325. if not method_entry:
  326. raise Exception(f"No matching method found for operationId: {name}")
  327. http_method, operation = method_entry
  328. path_params = {}
  329. query_params = {}
  330. body_params = {}
  331. for param in operation.get("parameters", []):
  332. param_name = param["name"]
  333. param_in = param["in"]
  334. if param_name in params:
  335. if param_in == "path":
  336. path_params[param_name] = params[param_name]
  337. elif param_in == "query":
  338. query_params[param_name] = params[param_name]
  339. final_url = f"{url}{route_path}"
  340. for key, value in path_params.items():
  341. final_url = final_url.replace(f"{{{key}}}", str(value))
  342. if query_params:
  343. query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
  344. final_url = f"{final_url}?{query_string}"
  345. if operation.get("requestBody", {}).get("content"):
  346. if params:
  347. body_params = params
  348. else:
  349. raise Exception(
  350. f"Request body expected for operation '{name}' but none found."
  351. )
  352. headers = {"Content-Type": "application/json"}
  353. if token:
  354. headers["Authorization"] = f"Bearer {token}"
  355. async with aiohttp.ClientSession() as session:
  356. request_method = getattr(session, http_method.lower())
  357. if http_method in ["post", "put", "patch"]:
  358. async with request_method(
  359. final_url, json=body_params, headers=headers
  360. ) as response:
  361. if response.status >= 400:
  362. text = await response.text()
  363. raise Exception(f"HTTP error {response.status}: {text}")
  364. return await response.json()
  365. else:
  366. async with request_method(final_url, headers=headers) as response:
  367. if response.status >= 400:
  368. text = await response.text()
  369. raise Exception(f"HTTP error {response.status}: {text}")
  370. return await response.json()
  371. except Exception as err:
  372. error = str(err)
  373. print("API Request Error:", error)
  374. return {"error": error}