tools.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. import inspect
  2. import logging
  3. import re
  4. import inspect
  5. import aiohttp
  6. import asyncio
  7. import yaml
  8. from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union, Optional
  9. from functools import update_wrapper, partial
  10. from fastapi import Request
  11. from pydantic import BaseModel, Field, create_model
  12. from langchain_core.utils.function_calling import (
  13. convert_to_openai_function as convert_pydantic_model_to_openai_function_spec,
  14. )
  15. from open_webui.models.tools import Tools
  16. from open_webui.models.users import UserModel
  17. from open_webui.utils.plugin import load_tool_module_by_id
  18. from open_webui.env import AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA
  19. import copy
  20. log = logging.getLogger(__name__)
  21. def get_async_tool_function_and_apply_extra_params(
  22. function: Callable, extra_params: dict
  23. ) -> Callable[..., Awaitable]:
  24. sig = inspect.signature(function)
  25. extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters}
  26. partial_func = partial(function, **extra_params)
  27. if inspect.iscoroutinefunction(function):
  28. update_wrapper(partial_func, function)
  29. return partial_func
  30. else:
  31. # Make it a coroutine function
  32. async def new_function(*args, **kwargs):
  33. return partial_func(*args, **kwargs)
  34. update_wrapper(new_function, function)
  35. return new_function
  36. def get_tools(
  37. request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
  38. ) -> dict[str, dict]:
  39. tools_dict = {}
  40. for tool_id in tool_ids:
  41. tool = Tools.get_tool_by_id(tool_id)
  42. if tool is None:
  43. if tool_id.startswith("server:"):
  44. server_idx = int(tool_id.split(":")[1])
  45. tool_server_connection = (
  46. request.app.state.config.TOOL_SERVER_CONNECTIONS[server_idx]
  47. )
  48. tool_server_data = None
  49. for server in request.app.state.TOOL_SERVERS:
  50. if server["idx"] == server_idx:
  51. tool_server_data = server
  52. break
  53. assert tool_server_data is not None
  54. specs = tool_server_data.get("specs", [])
  55. for spec in specs:
  56. function_name = spec["name"]
  57. auth_type = tool_server_connection.get("auth_type", "bearer")
  58. token = None
  59. if auth_type == "bearer":
  60. token = tool_server_connection.get("key", "")
  61. elif auth_type == "session":
  62. token = request.state.token.credentials
  63. def make_tool_function(function_name, token, tool_server_data):
  64. async def tool_function(**kwargs):
  65. print(
  66. f"Executing tool function {function_name} with params: {kwargs}"
  67. )
  68. return await execute_tool_server(
  69. token=token,
  70. url=tool_server_data["url"],
  71. name=function_name,
  72. params=kwargs,
  73. server_data=tool_server_data,
  74. )
  75. return tool_function
  76. tool_function = make_tool_function(
  77. function_name, token, tool_server_data
  78. )
  79. callable = get_async_tool_function_and_apply_extra_params(
  80. tool_function,
  81. {},
  82. )
  83. tool_dict = {
  84. "tool_id": tool_id,
  85. "callable": callable,
  86. "spec": spec,
  87. }
  88. # TODO: if collision, prepend toolkit name
  89. if function_name in tools_dict:
  90. log.warning(
  91. f"Tool {function_name} already exists in another tools!"
  92. )
  93. log.warning(f"Discarding {tool_id}.{function_name}")
  94. else:
  95. tools_dict[function_name] = tool_dict
  96. else:
  97. continue
  98. else:
  99. module = request.app.state.TOOLS.get(tool_id, None)
  100. if module is None:
  101. module, _ = load_tool_module_by_id(tool_id)
  102. request.app.state.TOOLS[tool_id] = module
  103. extra_params["__id__"] = tool_id
  104. # Set valves for the tool
  105. if hasattr(module, "valves") and hasattr(module, "Valves"):
  106. valves = Tools.get_tool_valves_by_id(tool_id) or {}
  107. module.valves = module.Valves(**valves)
  108. if hasattr(module, "UserValves"):
  109. extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
  110. **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
  111. )
  112. for spec in tool.specs:
  113. # TODO: Fix hack for OpenAI API
  114. # Some times breaks OpenAI but others don't. Leaving the comment
  115. for val in spec.get("parameters", {}).get("properties", {}).values():
  116. if val["type"] == "str":
  117. val["type"] = "string"
  118. # Remove internal reserved parameters (e.g. __id__, __user__)
  119. spec["parameters"]["properties"] = {
  120. key: val
  121. for key, val in spec["parameters"]["properties"].items()
  122. if not key.startswith("__")
  123. }
  124. # convert to function that takes only model params and inserts custom params
  125. function_name = spec["name"]
  126. tool_function = getattr(module, function_name)
  127. callable = get_async_tool_function_and_apply_extra_params(
  128. tool_function, extra_params
  129. )
  130. # TODO: Support Pydantic models as parameters
  131. if callable.__doc__ and callable.__doc__.strip() != "":
  132. s = re.split(":(param|return)", callable.__doc__, 1)
  133. spec["description"] = s[0]
  134. else:
  135. spec["description"] = function_name
  136. tool_dict = {
  137. "tool_id": tool_id,
  138. "callable": callable,
  139. "spec": spec,
  140. # Misc info
  141. "metadata": {
  142. "file_handler": hasattr(module, "file_handler")
  143. and module.file_handler,
  144. "citation": hasattr(module, "citation") and module.citation,
  145. },
  146. }
  147. # TODO: if collision, prepend toolkit name
  148. if function_name in tools_dict:
  149. log.warning(
  150. f"Tool {function_name} already exists in another tools!"
  151. )
  152. log.warning(f"Discarding {tool_id}.{function_name}")
  153. else:
  154. tools_dict[function_name] = tool_dict
  155. return tools_dict
  156. def parse_description(docstring: str | None) -> str:
  157. """
  158. Parse a function's docstring to extract the description.
  159. Args:
  160. docstring (str): The docstring to parse.
  161. Returns:
  162. str: The description.
  163. """
  164. if not docstring:
  165. return ""
  166. lines = [line.strip() for line in docstring.strip().split("\n")]
  167. description_lines: list[str] = []
  168. for line in lines:
  169. if re.match(r":param", line) or re.match(r":return", line):
  170. break
  171. description_lines.append(line)
  172. return "\n".join(description_lines)
  173. def parse_docstring(docstring):
  174. """
  175. Parse a function's docstring to extract parameter descriptions in reST format.
  176. Args:
  177. docstring (str): The docstring to parse.
  178. Returns:
  179. dict: A dictionary where keys are parameter names and values are descriptions.
  180. """
  181. if not docstring:
  182. return {}
  183. # Regex to match `:param name: description` format
  184. param_pattern = re.compile(r":param (\w+):\s*(.+)")
  185. param_descriptions = {}
  186. for line in docstring.splitlines():
  187. match = param_pattern.match(line.strip())
  188. if not match:
  189. continue
  190. param_name, param_description = match.groups()
  191. if param_name.startswith("__"):
  192. continue
  193. param_descriptions[param_name] = param_description
  194. return param_descriptions
  195. def convert_function_to_pydantic_model(func: Callable) -> type[BaseModel]:
  196. """
  197. Converts a Python function's type hints and docstring to a Pydantic model,
  198. including support for nested types, default values, and descriptions.
  199. Args:
  200. func: The function whose type hints and docstring should be converted.
  201. model_name: The name of the generated Pydantic model.
  202. Returns:
  203. A Pydantic model class.
  204. """
  205. type_hints = get_type_hints(func)
  206. signature = inspect.signature(func)
  207. parameters = signature.parameters
  208. docstring = func.__doc__
  209. descriptions = parse_docstring(docstring)
  210. tool_description = parse_description(docstring)
  211. field_defs = {}
  212. for name, param in parameters.items():
  213. type_hint = type_hints.get(name, Any)
  214. default_value = param.default if param.default is not param.empty else ...
  215. description = descriptions.get(name, None)
  216. if not description:
  217. field_defs[name] = type_hint, default_value
  218. continue
  219. field_defs[name] = type_hint, Field(default_value, description=description)
  220. model = create_model(func.__name__, **field_defs)
  221. model.__doc__ = tool_description
  222. return model
  223. def get_functions_from_tool(tool: object) -> list[Callable]:
  224. return [
  225. getattr(tool, func)
  226. for func in dir(tool)
  227. if callable(
  228. getattr(tool, func)
  229. ) # checks if the attribute is callable (a method or function).
  230. and not func.startswith(
  231. "__"
  232. ) # filters out special (dunder) methods like init, str, etc. — these are usually built-in functions of an object that you might not need to use directly.
  233. and not inspect.isclass(
  234. getattr(tool, func)
  235. ) # ensures that the callable is not a class itself, just a method or function.
  236. ]
  237. def get_tool_specs(tool_module: object) -> list[dict]:
  238. function_models = map(
  239. convert_function_to_pydantic_model, get_functions_from_tool(tool_module)
  240. )
  241. return [
  242. convert_pydantic_model_to_openai_function_spec(function_model)
  243. for function_model in function_models
  244. ]
  245. def resolve_schema(schema, components):
  246. """
  247. Recursively resolves a JSON schema using OpenAPI components.
  248. """
  249. if not schema:
  250. return {}
  251. if "$ref" in schema:
  252. ref_path = schema["$ref"]
  253. ref_parts = ref_path.strip("#/").split("/")
  254. resolved = components
  255. for part in ref_parts[1:]: # Skip the initial 'components'
  256. resolved = resolved.get(part, {})
  257. return resolve_schema(resolved, components)
  258. resolved_schema = copy.deepcopy(schema)
  259. # Recursively resolve inner schemas
  260. if "properties" in resolved_schema:
  261. for prop, prop_schema in resolved_schema["properties"].items():
  262. resolved_schema["properties"][prop] = resolve_schema(
  263. prop_schema, components
  264. )
  265. if "items" in resolved_schema:
  266. resolved_schema["items"] = resolve_schema(resolved_schema["items"], components)
  267. return resolved_schema
  268. def convert_openapi_to_tool_payload(openapi_spec):
  269. """
  270. Converts an OpenAPI specification into a custom tool payload structure.
  271. Args:
  272. openapi_spec (dict): The OpenAPI specification as a Python dict.
  273. Returns:
  274. list: A list of tool payloads.
  275. """
  276. tool_payload = []
  277. for path, methods in openapi_spec.get("paths", {}).items():
  278. for method, operation in methods.items():
  279. tool = {
  280. "type": "function",
  281. "name": operation.get("operationId"),
  282. "description": operation.get(
  283. "description", operation.get("summary", "No description available.")
  284. ),
  285. "parameters": {"type": "object", "properties": {}, "required": []},
  286. }
  287. # Extract path and query parameters
  288. for param in operation.get("parameters", []):
  289. param_name = param["name"]
  290. param_schema = param.get("schema", {})
  291. tool["parameters"]["properties"][param_name] = {
  292. "type": param_schema.get("type"),
  293. "description": param_schema.get("description", ""),
  294. }
  295. if param.get("required"):
  296. tool["parameters"]["required"].append(param_name)
  297. # Extract and resolve requestBody if available
  298. request_body = operation.get("requestBody")
  299. if request_body:
  300. content = request_body.get("content", {})
  301. json_schema = content.get("application/json", {}).get("schema")
  302. if json_schema:
  303. resolved_schema = resolve_schema(
  304. json_schema, openapi_spec.get("components", {})
  305. )
  306. if resolved_schema.get("properties"):
  307. tool["parameters"]["properties"].update(
  308. resolved_schema["properties"]
  309. )
  310. if "required" in resolved_schema:
  311. tool["parameters"]["required"] = list(
  312. set(
  313. tool["parameters"]["required"]
  314. + resolved_schema["required"]
  315. )
  316. )
  317. elif resolved_schema.get("type") == "array":
  318. tool["parameters"] = resolved_schema # special case for array
  319. tool_payload.append(tool)
  320. return tool_payload
  321. async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
  322. headers = {
  323. "Accept": "application/json",
  324. "Content-Type": "application/json",
  325. }
  326. if token:
  327. headers["Authorization"] = f"Bearer {token}"
  328. error = None
  329. try:
  330. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
  331. async with aiohttp.ClientSession(timeout=timeout) as session:
  332. async with session.get(url, headers=headers) as response:
  333. if response.status != 200:
  334. error_body = await response.json()
  335. raise Exception(error_body)
  336. # Check if URL ends with .yaml or .yml to determine format
  337. if url.lower().endswith((".yaml", ".yml")):
  338. text_content = await response.text()
  339. res = yaml.safe_load(text_content)
  340. else:
  341. res = await response.json()
  342. except Exception as err:
  343. log.exception(f"Could not fetch tool server spec from {url}")
  344. if isinstance(err, dict) and "detail" in err:
  345. error = err["detail"]
  346. else:
  347. error = str(err)
  348. raise Exception(error)
  349. data = {
  350. "openapi": res,
  351. "info": res.get("info", {}),
  352. "specs": convert_openapi_to_tool_payload(res),
  353. }
  354. print("Fetched data:", data)
  355. return data
  356. async def get_tool_servers_data(
  357. servers: List[Dict[str, Any]], session_token: Optional[str] = None
  358. ) -> List[Dict[str, Any]]:
  359. # Prepare list of enabled servers along with their original index
  360. server_entries = []
  361. for idx, server in enumerate(servers):
  362. if server.get("config", {}).get("enable"):
  363. url_path = server.get("path", "openapi.json")
  364. full_url = f"{server.get('url')}/{url_path}"
  365. auth_type = server.get("auth_type", "bearer")
  366. token = None
  367. if auth_type == "bearer":
  368. token = server.get("key", "")
  369. elif auth_type == "session":
  370. token = session_token
  371. server_entries.append((idx, server, full_url, token))
  372. # Create async tasks to fetch data
  373. tasks = [get_tool_server_data(token, url) for (_, _, url, token) in server_entries]
  374. # Execute tasks concurrently
  375. responses = await asyncio.gather(*tasks, return_exceptions=True)
  376. # Build final results with index and server metadata
  377. results = []
  378. for (idx, server, url, _), response in zip(server_entries, responses):
  379. if isinstance(response, Exception):
  380. print(f"Failed to connect to {url} OpenAPI tool server")
  381. continue
  382. results.append(
  383. {
  384. "idx": idx,
  385. "url": server.get("url"),
  386. "openapi": response.get("openapi"),
  387. "info": response.get("info"),
  388. "specs": response.get("specs"),
  389. }
  390. )
  391. return results
  392. async def execute_tool_server(
  393. token: str, url: str, name: str, params: Dict[str, Any], server_data: Dict[str, Any]
  394. ) -> Any:
  395. error = None
  396. try:
  397. openapi = server_data.get("openapi", {})
  398. paths = openapi.get("paths", {})
  399. matching_route = None
  400. for route_path, methods in paths.items():
  401. for http_method, operation in methods.items():
  402. if isinstance(operation, dict) and operation.get("operationId") == name:
  403. matching_route = (route_path, methods)
  404. break
  405. if matching_route:
  406. break
  407. if not matching_route:
  408. raise Exception(f"No matching route found for operationId: {name}")
  409. route_path, methods = matching_route
  410. method_entry = None
  411. for http_method, operation in methods.items():
  412. if operation.get("operationId") == name:
  413. method_entry = (http_method.lower(), operation)
  414. break
  415. if not method_entry:
  416. raise Exception(f"No matching method found for operationId: {name}")
  417. http_method, operation = method_entry
  418. path_params = {}
  419. query_params = {}
  420. body_params = {}
  421. for param in operation.get("parameters", []):
  422. param_name = param["name"]
  423. param_in = param["in"]
  424. if param_name in params:
  425. if param_in == "path":
  426. path_params[param_name] = params[param_name]
  427. elif param_in == "query":
  428. query_params[param_name] = params[param_name]
  429. final_url = f"{url}{route_path}"
  430. for key, value in path_params.items():
  431. final_url = final_url.replace(f"{{{key}}}", str(value))
  432. if query_params:
  433. query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
  434. final_url = f"{final_url}?{query_string}"
  435. if operation.get("requestBody", {}).get("content"):
  436. if params:
  437. body_params = params
  438. else:
  439. raise Exception(
  440. f"Request body expected for operation '{name}' but none found."
  441. )
  442. headers = {"Content-Type": "application/json"}
  443. if token:
  444. headers["Authorization"] = f"Bearer {token}"
  445. async with aiohttp.ClientSession() as session:
  446. request_method = getattr(session, http_method.lower())
  447. if http_method in ["post", "put", "patch"]:
  448. async with request_method(
  449. final_url, json=body_params, headers=headers
  450. ) as response:
  451. if response.status >= 400:
  452. text = await response.text()
  453. raise Exception(f"HTTP error {response.status}: {text}")
  454. return await response.json()
  455. else:
  456. async with request_method(final_url, headers=headers) as response:
  457. if response.status >= 400:
  458. text = await response.text()
  459. raise Exception(f"HTTP error {response.status}: {text}")
  460. return await response.json()
  461. except Exception as err:
  462. error = str(err)
  463. print("API Request Error:", error)
  464. return {"error": error}