tools.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. import inspect
  2. import logging
  3. import re
  4. import inspect
  5. import aiohttp
  6. import asyncio
  7. from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union, Optional
  8. from functools import update_wrapper, partial
  9. from fastapi import Request
  10. from pydantic import BaseModel, Field, create_model
  11. from langchain_core.utils.function_calling import convert_to_openai_function
  12. from open_webui.models.tools import Tools
  13. from open_webui.models.users import UserModel
  14. from open_webui.utils.plugin import load_tools_module_by_id
  15. import copy
  16. log = logging.getLogger(__name__)
  17. def get_async_tool_function_and_apply_extra_params(
  18. function: Callable, extra_params: dict
  19. ) -> Callable[..., Awaitable]:
  20. sig = inspect.signature(function)
  21. extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters}
  22. partial_func = partial(function, **extra_params)
  23. if inspect.iscoroutinefunction(function):
  24. update_wrapper(partial_func, function)
  25. return partial_func
  26. else:
  27. # Make it a coroutine function
  28. async def new_function(*args, **kwargs):
  29. return partial_func(*args, **kwargs)
  30. update_wrapper(new_function, function)
  31. return new_function
  32. def get_tools(
  33. request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
  34. ) -> dict[str, dict]:
  35. tools_dict = {}
  36. for tool_id in tool_ids:
  37. tool = Tools.get_tool_by_id(tool_id)
  38. if tool is None:
  39. if tool_id.startswith("server:"):
  40. server_idx = int(tool_id.split(":")[1])
  41. tool_server_connection = (
  42. request.app.state.config.TOOL_SERVER_CONNECTIONS[server_idx]
  43. )
  44. tool_server_data = request.app.state.TOOL_SERVERS[server_idx]
  45. specs = tool_server_data.get("specs", [])
  46. for spec in specs:
  47. function_name = spec["name"]
  48. auth_type = tool_server_connection.get("auth_type", "bearer")
  49. token = None
  50. if auth_type == "bearer":
  51. token = tool_server_connection.get("key", "")
  52. elif auth_type == "session":
  53. token = request.state.token.credentials
  54. callable = get_async_tool_function_and_apply_extra_params(
  55. execute_tool_server,
  56. {
  57. "token": token,
  58. "url": tool_server_data["url"],
  59. "name": function_name,
  60. "server_data": tool_server_data,
  61. },
  62. )
  63. tool_dict = {
  64. "tool_id": tool_id,
  65. "callable": callable,
  66. "spec": spec,
  67. }
  68. # TODO: if collision, prepend toolkit name
  69. if function_name in tools_dict:
  70. log.warning(
  71. f"Tool {function_name} already exists in another tools!"
  72. )
  73. log.warning(f"Discarding {tool_id}.{function_name}")
  74. else:
  75. tools_dict[function_name] = tool_dict
  76. else:
  77. continue
  78. else:
  79. module = request.app.state.TOOLS.get(tool_id, None)
  80. if module is None:
  81. module, _ = load_tools_module_by_id(tool_id)
  82. request.app.state.TOOLS[tool_id] = module
  83. extra_params["__id__"] = tool_id
  84. # Set valves for the tool
  85. if hasattr(module, "valves") and hasattr(module, "Valves"):
  86. valves = Tools.get_tool_valves_by_id(tool_id) or {}
  87. module.valves = module.Valves(**valves)
  88. if hasattr(module, "UserValves"):
  89. extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
  90. **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
  91. )
  92. for spec in tool.specs:
  93. # TODO: Fix hack for OpenAI API
  94. # Some times breaks OpenAI but others don't. Leaving the comment
  95. for val in spec.get("parameters", {}).get("properties", {}).values():
  96. if val["type"] == "str":
  97. val["type"] = "string"
  98. # Remove internal reserved parameters (e.g. __id__, __user__)
  99. spec["parameters"]["properties"] = {
  100. key: val
  101. for key, val in spec["parameters"]["properties"].items()
  102. if not key.startswith("__")
  103. }
  104. # convert to function that takes only model params and inserts custom params
  105. function_name = spec["name"]
  106. tool_function = getattr(module, function_name)
  107. callable = get_async_tool_function_and_apply_extra_params(
  108. tool_function, extra_params
  109. )
  110. # TODO: Support Pydantic models as parameters
  111. if callable.__doc__ and callable.__doc__.strip() != "":
  112. s = re.split(":(param|return)", callable.__doc__, 1)
  113. spec["description"] = s[0]
  114. else:
  115. spec["description"] = function_name
  116. tool_dict = {
  117. "tool_id": tool_id,
  118. "callable": callable,
  119. "spec": spec,
  120. # Misc info
  121. "metadata": {
  122. "file_handler": hasattr(module, "file_handler")
  123. and module.file_handler,
  124. "citation": hasattr(module, "citation") and module.citation,
  125. },
  126. }
  127. # TODO: if collision, prepend toolkit name
  128. if function_name in tools_dict:
  129. log.warning(
  130. f"Tool {function_name} already exists in another tools!"
  131. )
  132. log.warning(f"Discarding {tool_id}.{function_name}")
  133. else:
  134. tools_dict[function_name] = tool_dict
  135. return tools_dict
  136. def parse_description(docstring: str | None) -> str:
  137. """
  138. Parse a function's docstring to extract the description.
  139. Args:
  140. docstring (str): The docstring to parse.
  141. Returns:
  142. str: The description.
  143. """
  144. if not docstring:
  145. return ""
  146. lines = [line.strip() for line in docstring.strip().split("\n")]
  147. description_lines: list[str] = []
  148. for line in lines:
  149. if re.match(r":param", line) or re.match(r":return", line):
  150. break
  151. description_lines.append(line)
  152. return "\n".join(description_lines)
  153. def parse_docstring(docstring):
  154. """
  155. Parse a function's docstring to extract parameter descriptions in reST format.
  156. Args:
  157. docstring (str): The docstring to parse.
  158. Returns:
  159. dict: A dictionary where keys are parameter names and values are descriptions.
  160. """
  161. if not docstring:
  162. return {}
  163. # Regex to match `:param name: description` format
  164. param_pattern = re.compile(r":param (\w+):\s*(.+)")
  165. param_descriptions = {}
  166. for line in docstring.splitlines():
  167. match = param_pattern.match(line.strip())
  168. if not match:
  169. continue
  170. param_name, param_description = match.groups()
  171. if param_name.startswith("__"):
  172. continue
  173. param_descriptions[param_name] = param_description
  174. return param_descriptions
  175. def function_to_pydantic_model(func: Callable) -> type[BaseModel]:
  176. """
  177. Converts a Python function's type hints and docstring to a Pydantic model,
  178. including support for nested types, default values, and descriptions.
  179. Args:
  180. func: The function whose type hints and docstring should be converted.
  181. model_name: The name of the generated Pydantic model.
  182. Returns:
  183. A Pydantic model class.
  184. """
  185. type_hints = get_type_hints(func)
  186. signature = inspect.signature(func)
  187. parameters = signature.parameters
  188. docstring = func.__doc__
  189. descriptions = parse_docstring(docstring)
  190. tool_description = parse_description(docstring)
  191. field_defs = {}
  192. for name, param in parameters.items():
  193. type_hint = type_hints.get(name, Any)
  194. default_value = param.default if param.default is not param.empty else ...
  195. description = descriptions.get(name, None)
  196. if not description:
  197. field_defs[name] = type_hint, default_value
  198. continue
  199. field_defs[name] = type_hint, Field(default_value, description=description)
  200. model = create_model(func.__name__, **field_defs)
  201. model.__doc__ = tool_description
  202. return model
  203. def get_callable_attributes(tool: object) -> list[Callable]:
  204. return [
  205. getattr(tool, func)
  206. for func in dir(tool)
  207. if callable(getattr(tool, func))
  208. and not func.startswith("__")
  209. and not inspect.isclass(getattr(tool, func))
  210. ]
  211. def get_tools_specs(tool_class: object) -> list[dict]:
  212. function_model_list = map(
  213. function_to_pydantic_model, get_callable_attributes(tool_class)
  214. )
  215. return [
  216. convert_to_openai_function(function_model)
  217. for function_model in function_model_list
  218. ]
  219. def resolve_schema(schema, components):
  220. """
  221. Recursively resolves a JSON schema using OpenAPI components.
  222. """
  223. if not schema:
  224. return {}
  225. if "$ref" in schema:
  226. ref_path = schema["$ref"]
  227. ref_parts = ref_path.strip("#/").split("/")
  228. resolved = components
  229. for part in ref_parts[1:]: # Skip the initial 'components'
  230. resolved = resolved.get(part, {})
  231. return resolve_schema(resolved, components)
  232. resolved_schema = copy.deepcopy(schema)
  233. # Recursively resolve inner schemas
  234. if "properties" in resolved_schema:
  235. for prop, prop_schema in resolved_schema["properties"].items():
  236. resolved_schema["properties"][prop] = resolve_schema(
  237. prop_schema, components
  238. )
  239. if "items" in resolved_schema:
  240. resolved_schema["items"] = resolve_schema(resolved_schema["items"], components)
  241. return resolved_schema
  242. def convert_openapi_to_tool_payload(openapi_spec):
  243. """
  244. Converts an OpenAPI specification into a custom tool payload structure.
  245. Args:
  246. openapi_spec (dict): The OpenAPI specification as a Python dict.
  247. Returns:
  248. list: A list of tool payloads.
  249. """
  250. tool_payload = []
  251. for path, methods in openapi_spec.get("paths", {}).items():
  252. for method, operation in methods.items():
  253. tool = {
  254. "type": "function",
  255. "name": operation.get("operationId"),
  256. "description": operation.get("summary", "No description available."),
  257. "parameters": {"type": "object", "properties": {}, "required": []},
  258. }
  259. # Extract path and query parameters
  260. for param in operation.get("parameters", []):
  261. param_name = param["name"]
  262. param_schema = param.get("schema", {})
  263. tool["parameters"]["properties"][param_name] = {
  264. "type": param_schema.get("type"),
  265. "description": param_schema.get("description", ""),
  266. }
  267. if param.get("required"):
  268. tool["parameters"]["required"].append(param_name)
  269. # Extract and resolve requestBody if available
  270. request_body = operation.get("requestBody")
  271. if request_body:
  272. content = request_body.get("content", {})
  273. json_schema = content.get("application/json", {}).get("schema")
  274. if json_schema:
  275. resolved_schema = resolve_schema(
  276. json_schema, openapi_spec.get("components", {})
  277. )
  278. if resolved_schema.get("properties"):
  279. tool["parameters"]["properties"].update(
  280. resolved_schema["properties"]
  281. )
  282. if "required" in resolved_schema:
  283. tool["parameters"]["required"] = list(
  284. set(
  285. tool["parameters"]["required"]
  286. + resolved_schema["required"]
  287. )
  288. )
  289. elif resolved_schema.get("type") == "array":
  290. tool["parameters"] = resolved_schema # special case for array
  291. tool_payload.append(tool)
  292. return tool_payload
  293. async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
  294. headers = {
  295. "Accept": "application/json",
  296. "Content-Type": "application/json",
  297. }
  298. if token:
  299. headers["Authorization"] = f"Bearer {token}"
  300. error = None
  301. try:
  302. async with aiohttp.ClientSession() as session:
  303. async with session.get(url, headers=headers) as response:
  304. if response.status != 200:
  305. error_body = await response.json()
  306. raise Exception(error_body)
  307. res = await response.json()
  308. except Exception as err:
  309. print("Error:", err)
  310. if isinstance(err, dict) and "detail" in err:
  311. error = err["detail"]
  312. else:
  313. error = str(err)
  314. raise Exception(error)
  315. data = {
  316. "openapi": res,
  317. "info": res.get("info", {}),
  318. "specs": convert_openapi_to_tool_payload(res),
  319. }
  320. print("Fetched data:", data)
  321. return data
  322. async def get_tool_servers_data(
  323. servers: List[Dict[str, Any]], session_token: Optional[str] = None
  324. ) -> List[Dict[str, Any]]:
  325. # Prepare list of enabled servers along with their original index
  326. server_entries = []
  327. for idx, server in enumerate(servers):
  328. if server.get("config", {}).get("enable"):
  329. url_path = server.get("path", "openapi.json")
  330. full_url = f"{server.get('url')}/{url_path}"
  331. auth_type = server.get("auth_type", "bearer")
  332. token = None
  333. if auth_type == "bearer":
  334. token = server.get("key", "")
  335. elif auth_type == "session":
  336. token = session_token
  337. server_entries.append((idx, server, full_url, token))
  338. # Create async tasks to fetch data
  339. tasks = [get_tool_server_data(token, url) for (_, _, url, token) in server_entries]
  340. # Execute tasks concurrently
  341. responses = await asyncio.gather(*tasks, return_exceptions=True)
  342. # Build final results with index and server metadata
  343. results = []
  344. for (idx, server, url, _), response in zip(server_entries, responses):
  345. if isinstance(response, Exception):
  346. print(f"Failed to connect to {url} OpenAPI tool server")
  347. continue
  348. results.append(
  349. {
  350. "idx": idx,
  351. "url": server.get("url"),
  352. "openapi": response.get("openapi"),
  353. "info": response.get("info"),
  354. "specs": response.get("specs"),
  355. }
  356. )
  357. return results
  358. async def execute_tool_server(
  359. token: str, url: str, name: str, params: Dict[str, Any], server_data: Dict[str, Any]
  360. ) -> Any:
  361. error = None
  362. try:
  363. openapi = server_data.get("openapi", {})
  364. paths = openapi.get("paths", {})
  365. matching_route = None
  366. for route_path, methods in paths.items():
  367. for http_method, operation in methods.items():
  368. if isinstance(operation, dict) and operation.get("operationId") == name:
  369. matching_route = (route_path, methods)
  370. break
  371. if matching_route:
  372. break
  373. if not matching_route:
  374. raise Exception(f"No matching route found for operationId: {name}")
  375. route_path, methods = matching_route
  376. method_entry = None
  377. for http_method, operation in methods.items():
  378. if operation.get("operationId") == name:
  379. method_entry = (http_method.lower(), operation)
  380. break
  381. if not method_entry:
  382. raise Exception(f"No matching method found for operationId: {name}")
  383. http_method, operation = method_entry
  384. path_params = {}
  385. query_params = {}
  386. body_params = {}
  387. for param in operation.get("parameters", []):
  388. param_name = param["name"]
  389. param_in = param["in"]
  390. if param_name in params:
  391. if param_in == "path":
  392. path_params[param_name] = params[param_name]
  393. elif param_in == "query":
  394. query_params[param_name] = params[param_name]
  395. final_url = f"{url}{route_path}"
  396. for key, value in path_params.items():
  397. final_url = final_url.replace(f"{{{key}}}", str(value))
  398. if query_params:
  399. query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
  400. final_url = f"{final_url}?{query_string}"
  401. if operation.get("requestBody", {}).get("content"):
  402. if params:
  403. body_params = params
  404. else:
  405. raise Exception(
  406. f"Request body expected for operation '{name}' but none found."
  407. )
  408. headers = {"Content-Type": "application/json"}
  409. if token:
  410. headers["Authorization"] = f"Bearer {token}"
  411. async with aiohttp.ClientSession() as session:
  412. request_method = getattr(session, http_method.lower())
  413. if http_method in ["post", "put", "patch"]:
  414. async with request_method(
  415. final_url, json=body_params, headers=headers
  416. ) as response:
  417. if response.status >= 400:
  418. text = await response.text()
  419. raise Exception(f"HTTP error {response.status}: {text}")
  420. return await response.json()
  421. else:
  422. async with request_method(final_url, headers=headers) as response:
  423. if response.status >= 400:
  424. text = await response.text()
  425. raise Exception(f"HTTP error {response.status}: {text}")
  426. return await response.json()
  427. except Exception as err:
  428. error = str(err)
  429. print("API Request Error:", error)
  430. return {"error": error}