tools.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573
  1. import inspect
  2. import logging
  3. import re
  4. import inspect
  5. import aiohttp
  6. import asyncio
  7. import yaml
  8. from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union, Optional
  9. from functools import update_wrapper, partial
  10. from fastapi import Request
  11. from pydantic import BaseModel, Field, create_model
  12. from langchain_core.utils.function_calling import convert_to_openai_function
  13. from open_webui.models.tools import Tools
  14. from open_webui.models.users import UserModel
  15. from open_webui.utils.plugin import load_tool_module_by_id
  16. from open_webui.env import AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA
  17. import copy
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
  19. def get_async_tool_function_and_apply_extra_params(
  20. function: Callable, extra_params: dict
  21. ) -> Callable[..., Awaitable]:
  22. sig = inspect.signature(function)
  23. extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters}
  24. partial_func = partial(function, **extra_params)
  25. if inspect.iscoroutinefunction(function):
  26. update_wrapper(partial_func, function)
  27. return partial_func
  28. else:
  29. # Make it a coroutine function
  30. async def new_function(*args, **kwargs):
  31. return partial_func(*args, **kwargs)
  32. update_wrapper(new_function, function)
  33. return new_function
  34. def get_tools(
  35. request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
  36. ) -> dict[str, dict]:
  37. tools_dict = {}
  38. for tool_id in tool_ids:
  39. tool = Tools.get_tool_by_id(tool_id)
  40. if tool is None:
  41. if tool_id.startswith("server:"):
  42. server_idx = int(tool_id.split(":")[1])
  43. tool_server_connection = (
  44. request.app.state.config.TOOL_SERVER_CONNECTIONS[server_idx]
  45. )
  46. tool_server_data = None
  47. for server in request.app.state.TOOL_SERVERS:
  48. if server["idx"] == server_idx:
  49. tool_server_data = server
  50. break
  51. assert tool_server_data is not None
  52. specs = tool_server_data.get("specs", [])
  53. for spec in specs:
  54. function_name = spec["name"]
  55. auth_type = tool_server_connection.get("auth_type", "bearer")
  56. token = None
  57. if auth_type == "bearer":
  58. token = tool_server_connection.get("key", "")
  59. elif auth_type == "session":
  60. token = request.state.token.credentials
  61. def make_tool_function(function_name, token, tool_server_data):
  62. async def tool_function(**kwargs):
  63. print(
  64. f"Executing tool function {function_name} with params: {kwargs}"
  65. )
  66. return await execute_tool_server(
  67. token=token,
  68. url=tool_server_data["url"],
  69. name=function_name,
  70. params=kwargs,
  71. server_data=tool_server_data,
  72. )
  73. return tool_function
  74. tool_function = make_tool_function(
  75. function_name, token, tool_server_data
  76. )
  77. callable = get_async_tool_function_and_apply_extra_params(
  78. tool_function,
  79. {},
  80. )
  81. tool_dict = {
  82. "tool_id": tool_id,
  83. "callable": callable,
  84. "spec": spec,
  85. }
  86. # TODO: if collision, prepend toolkit name
  87. if function_name in tools_dict:
  88. log.warning(
  89. f"Tool {function_name} already exists in another tools!"
  90. )
  91. log.warning(f"Discarding {tool_id}.{function_name}")
  92. else:
  93. tools_dict[function_name] = tool_dict
  94. else:
  95. continue
  96. else:
  97. module = request.app.state.TOOLS.get(tool_id, None)
  98. if module is None:
  99. module, _ = load_tool_module_by_id(tool_id)
  100. request.app.state.TOOLS[tool_id] = module
  101. extra_params["__id__"] = tool_id
  102. # Set valves for the tool
  103. if hasattr(module, "valves") and hasattr(module, "Valves"):
  104. valves = Tools.get_tool_valves_by_id(tool_id) or {}
  105. module.valves = module.Valves(**valves)
  106. if hasattr(module, "UserValves"):
  107. extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
  108. **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
  109. )
  110. for spec in tool.specs:
  111. # TODO: Fix hack for OpenAI API
  112. # Some times breaks OpenAI but others don't. Leaving the comment
  113. for val in spec.get("parameters", {}).get("properties", {}).values():
  114. if val["type"] == "str":
  115. val["type"] = "string"
  116. # Remove internal reserved parameters (e.g. __id__, __user__)
  117. spec["parameters"]["properties"] = {
  118. key: val
  119. for key, val in spec["parameters"]["properties"].items()
  120. if not key.startswith("__")
  121. }
  122. # convert to function that takes only model params and inserts custom params
  123. function_name = spec["name"]
  124. tool_function = getattr(module, function_name)
  125. callable = get_async_tool_function_and_apply_extra_params(
  126. tool_function, extra_params
  127. )
  128. # TODO: Support Pydantic models as parameters
  129. if callable.__doc__ and callable.__doc__.strip() != "":
  130. s = re.split(":(param|return)", callable.__doc__, 1)
  131. spec["description"] = s[0]
  132. else:
  133. spec["description"] = function_name
  134. tool_dict = {
  135. "tool_id": tool_id,
  136. "callable": callable,
  137. "spec": spec,
  138. # Misc info
  139. "metadata": {
  140. "file_handler": hasattr(module, "file_handler")
  141. and module.file_handler,
  142. "citation": hasattr(module, "citation") and module.citation,
  143. },
  144. }
  145. # TODO: if collision, prepend toolkit name
  146. if function_name in tools_dict:
  147. log.warning(
  148. f"Tool {function_name} already exists in another tools!"
  149. )
  150. log.warning(f"Discarding {tool_id}.{function_name}")
  151. else:
  152. tools_dict[function_name] = tool_dict
  153. return tools_dict
  154. def parse_description(docstring: str | None) -> str:
  155. """
  156. Parse a function's docstring to extract the description.
  157. Args:
  158. docstring (str): The docstring to parse.
  159. Returns:
  160. str: The description.
  161. """
  162. if not docstring:
  163. return ""
  164. lines = [line.strip() for line in docstring.strip().split("\n")]
  165. description_lines: list[str] = []
  166. for line in lines:
  167. if re.match(r":param", line) or re.match(r":return", line):
  168. break
  169. description_lines.append(line)
  170. return "\n".join(description_lines)
  171. def parse_docstring(docstring):
  172. """
  173. Parse a function's docstring to extract parameter descriptions in reST format.
  174. Args:
  175. docstring (str): The docstring to parse.
  176. Returns:
  177. dict: A dictionary where keys are parameter names and values are descriptions.
  178. """
  179. if not docstring:
  180. return {}
  181. # Regex to match `:param name: description` format
  182. param_pattern = re.compile(r":param (\w+):\s*(.+)")
  183. param_descriptions = {}
  184. for line in docstring.splitlines():
  185. match = param_pattern.match(line.strip())
  186. if not match:
  187. continue
  188. param_name, param_description = match.groups()
  189. if param_name.startswith("__"):
  190. continue
  191. param_descriptions[param_name] = param_description
  192. return param_descriptions
  193. def function_to_pydantic_model(func: Callable) -> type[BaseModel]:
  194. """
  195. Converts a Python function's type hints and docstring to a Pydantic model,
  196. including support for nested types, default values, and descriptions.
  197. Args:
  198. func: The function whose type hints and docstring should be converted.
  199. model_name: The name of the generated Pydantic model.
  200. Returns:
  201. A Pydantic model class.
  202. """
  203. type_hints = get_type_hints(func)
  204. signature = inspect.signature(func)
  205. parameters = signature.parameters
  206. docstring = func.__doc__
  207. descriptions = parse_docstring(docstring)
  208. tool_description = parse_description(docstring)
  209. field_defs = {}
  210. for name, param in parameters.items():
  211. type_hint = type_hints.get(name, Any)
  212. default_value = param.default if param.default is not param.empty else ...
  213. description = descriptions.get(name, None)
  214. if not description:
  215. field_defs[name] = type_hint, default_value
  216. continue
  217. field_defs[name] = type_hint, Field(default_value, description=description)
  218. model = create_model(func.__name__, **field_defs)
  219. model.__doc__ = tool_description
  220. return model
  221. def get_functions_from_tool(tool: object) -> list[Callable]:
  222. return [
  223. getattr(tool, func)
  224. for func in dir(tool)
  225. if callable(
  226. getattr(tool, func)
  227. ) # checks if the attribute is callable (a method or function).
  228. and not func.startswith(
  229. "__"
  230. ) # filters out special (dunder) methods like init, str, etc. — these are usually built-in functions of an object that you might not need to use directly.
  231. and not inspect.isclass(
  232. getattr(tool, func)
  233. ) # ensures that the callable is not a class itself, just a method or function.
  234. ]
  235. def get_tool_specs(tool_module: object) -> list[dict]:
  236. function_models = map(
  237. function_to_pydantic_model, get_functions_from_tool(tool_module)
  238. )
  239. return [
  240. convert_to_openai_function(function_model) for function_model in function_models
  241. ]
  242. def resolve_schema(schema, components):
  243. """
  244. Recursively resolves a JSON schema using OpenAPI components.
  245. """
  246. if not schema:
  247. return {}
  248. if "$ref" in schema:
  249. ref_path = schema["$ref"]
  250. ref_parts = ref_path.strip("#/").split("/")
  251. resolved = components
  252. for part in ref_parts[1:]: # Skip the initial 'components'
  253. resolved = resolved.get(part, {})
  254. return resolve_schema(resolved, components)
  255. resolved_schema = copy.deepcopy(schema)
  256. # Recursively resolve inner schemas
  257. if "properties" in resolved_schema:
  258. for prop, prop_schema in resolved_schema["properties"].items():
  259. resolved_schema["properties"][prop] = resolve_schema(
  260. prop_schema, components
  261. )
  262. if "items" in resolved_schema:
  263. resolved_schema["items"] = resolve_schema(resolved_schema["items"], components)
  264. return resolved_schema
  265. def convert_openapi_to_tool_payload(openapi_spec):
  266. """
  267. Converts an OpenAPI specification into a custom tool payload structure.
  268. Args:
  269. openapi_spec (dict): The OpenAPI specification as a Python dict.
  270. Returns:
  271. list: A list of tool payloads.
  272. """
  273. tool_payload = []
  274. for path, methods in openapi_spec.get("paths", {}).items():
  275. for method, operation in methods.items():
  276. tool = {
  277. "type": "function",
  278. "name": operation.get("operationId"),
  279. "description": operation.get(
  280. "description", operation.get("summary", "No description available.")
  281. ),
  282. "parameters": {"type": "object", "properties": {}, "required": []},
  283. }
  284. # Extract path and query parameters
  285. for param in operation.get("parameters", []):
  286. param_name = param["name"]
  287. param_schema = param.get("schema", {})
  288. tool["parameters"]["properties"][param_name] = {
  289. "type": param_schema.get("type"),
  290. "description": param_schema.get("description", ""),
  291. }
  292. if param.get("required"):
  293. tool["parameters"]["required"].append(param_name)
  294. # Extract and resolve requestBody if available
  295. request_body = operation.get("requestBody")
  296. if request_body:
  297. content = request_body.get("content", {})
  298. json_schema = content.get("application/json", {}).get("schema")
  299. if json_schema:
  300. resolved_schema = resolve_schema(
  301. json_schema, openapi_spec.get("components", {})
  302. )
  303. if resolved_schema.get("properties"):
  304. tool["parameters"]["properties"].update(
  305. resolved_schema["properties"]
  306. )
  307. if "required" in resolved_schema:
  308. tool["parameters"]["required"] = list(
  309. set(
  310. tool["parameters"]["required"]
  311. + resolved_schema["required"]
  312. )
  313. )
  314. elif resolved_schema.get("type") == "array":
  315. tool["parameters"] = resolved_schema # special case for array
  316. tool_payload.append(tool)
  317. return tool_payload
  318. async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
  319. headers = {
  320. "Accept": "application/json",
  321. "Content-Type": "application/json",
  322. }
  323. if token:
  324. headers["Authorization"] = f"Bearer {token}"
  325. error = None
  326. try:
  327. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
  328. async with aiohttp.ClientSession(timeout=timeout) as session:
  329. async with session.get(url, headers=headers) as response:
  330. if response.status != 200:
  331. error_body = await response.json()
  332. raise Exception(error_body)
  333. # Check if URL ends with .yaml or .yml to determine format
  334. if url.lower().endswith((".yaml", ".yml")):
  335. text_content = await response.text()
  336. res = yaml.safe_load(text_content)
  337. else:
  338. res = await response.json()
  339. except Exception as err:
  340. log.exception(f"Could not fetch tool server spec from {url}")
  341. if isinstance(err, dict) and "detail" in err:
  342. error = err["detail"]
  343. else:
  344. error = str(err)
  345. raise Exception(error)
  346. data = {
  347. "openapi": res,
  348. "info": res.get("info", {}),
  349. "specs": convert_openapi_to_tool_payload(res),
  350. }
  351. print("Fetched data:", data)
  352. return data
  353. async def get_tool_servers_data(
  354. servers: List[Dict[str, Any]], session_token: Optional[str] = None
  355. ) -> List[Dict[str, Any]]:
  356. # Prepare list of enabled servers along with their original index
  357. server_entries = []
  358. for idx, server in enumerate(servers):
  359. if server.get("config", {}).get("enable"):
  360. url_path = server.get("path", "openapi.json")
  361. full_url = f"{server.get('url')}/{url_path}"
  362. auth_type = server.get("auth_type", "bearer")
  363. token = None
  364. if auth_type == "bearer":
  365. token = server.get("key", "")
  366. elif auth_type == "session":
  367. token = session_token
  368. server_entries.append((idx, server, full_url, token))
  369. # Create async tasks to fetch data
  370. tasks = [get_tool_server_data(token, url) for (_, _, url, token) in server_entries]
  371. # Execute tasks concurrently
  372. responses = await asyncio.gather(*tasks, return_exceptions=True)
  373. # Build final results with index and server metadata
  374. results = []
  375. for (idx, server, url, _), response in zip(server_entries, responses):
  376. if isinstance(response, Exception):
  377. print(f"Failed to connect to {url} OpenAPI tool server")
  378. continue
  379. results.append(
  380. {
  381. "idx": idx,
  382. "url": server.get("url"),
  383. "openapi": response.get("openapi"),
  384. "info": response.get("info"),
  385. "specs": response.get("specs"),
  386. }
  387. )
  388. return results
  389. async def execute_tool_server(
  390. token: str, url: str, name: str, params: Dict[str, Any], server_data: Dict[str, Any]
  391. ) -> Any:
  392. error = None
  393. try:
  394. openapi = server_data.get("openapi", {})
  395. paths = openapi.get("paths", {})
  396. matching_route = None
  397. for route_path, methods in paths.items():
  398. for http_method, operation in methods.items():
  399. if isinstance(operation, dict) and operation.get("operationId") == name:
  400. matching_route = (route_path, methods)
  401. break
  402. if matching_route:
  403. break
  404. if not matching_route:
  405. raise Exception(f"No matching route found for operationId: {name}")
  406. route_path, methods = matching_route
  407. method_entry = None
  408. for http_method, operation in methods.items():
  409. if operation.get("operationId") == name:
  410. method_entry = (http_method.lower(), operation)
  411. break
  412. if not method_entry:
  413. raise Exception(f"No matching method found for operationId: {name}")
  414. http_method, operation = method_entry
  415. path_params = {}
  416. query_params = {}
  417. body_params = {}
  418. for param in operation.get("parameters", []):
  419. param_name = param["name"]
  420. param_in = param["in"]
  421. if param_name in params:
  422. if param_in == "path":
  423. path_params[param_name] = params[param_name]
  424. elif param_in == "query":
  425. query_params[param_name] = params[param_name]
  426. final_url = f"{url}{route_path}"
  427. for key, value in path_params.items():
  428. final_url = final_url.replace(f"{{{key}}}", str(value))
  429. if query_params:
  430. query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
  431. final_url = f"{final_url}?{query_string}"
  432. if operation.get("requestBody", {}).get("content"):
  433. if params:
  434. body_params = params
  435. else:
  436. raise Exception(
  437. f"Request body expected for operation '{name}' but none found."
  438. )
  439. headers = {"Content-Type": "application/json"}
  440. if token:
  441. headers["Authorization"] = f"Bearer {token}"
  442. async with aiohttp.ClientSession() as session:
  443. request_method = getattr(session, http_method.lower())
  444. if http_method in ["post", "put", "patch"]:
  445. async with request_method(
  446. final_url, json=body_params, headers=headers
  447. ) as response:
  448. if response.status >= 400:
  449. text = await response.text()
  450. raise Exception(f"HTTP error {response.status}: {text}")
  451. return await response.json()
  452. else:
  453. async with request_method(final_url, headers=headers) as response:
  454. if response.status >= 400:
  455. text = await response.text()
  456. raise Exception(f"HTTP error {response.status}: {text}")
  457. return await response.json()
  458. except Exception as err:
  459. error = str(err)
  460. print("API Request Error:", error)
  461. return {"error": error}