tools.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597
  1. import inspect
  2. import logging
  3. import re
  4. import inspect
  5. import aiohttp
  6. import asyncio
  7. import yaml
  8. from pydantic import BaseModel
  9. from pydantic.fields import FieldInfo
  10. from typing import (
  11. Any,
  12. Awaitable,
  13. Callable,
  14. get_type_hints,
  15. get_args,
  16. get_origin,
  17. Dict,
  18. List,
  19. Tuple,
  20. Union,
  21. Optional,
  22. Type,
  23. )
  24. from functools import update_wrapper, partial
  25. from fastapi import Request
  26. from pydantic import BaseModel, Field, create_model
  27. from langchain_core.utils.function_calling import (
  28. convert_to_openai_function as convert_pydantic_model_to_openai_function_spec,
  29. )
  30. from open_webui.models.tools import Tools
  31. from open_webui.models.users import UserModel
  32. from open_webui.utils.plugin import load_tool_module_by_id
  33. from open_webui.env import AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA
  34. import copy
  35. log = logging.getLogger(__name__)
  36. def get_async_tool_function_and_apply_extra_params(
  37. function: Callable, extra_params: dict
  38. ) -> Callable[..., Awaitable]:
  39. sig = inspect.signature(function)
  40. extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters}
  41. partial_func = partial(function, **extra_params)
  42. if inspect.iscoroutinefunction(function):
  43. update_wrapper(partial_func, function)
  44. return partial_func
  45. else:
  46. # Make it a coroutine function
  47. async def new_function(*args, **kwargs):
  48. return partial_func(*args, **kwargs)
  49. update_wrapper(new_function, function)
  50. return new_function
  51. def get_tools(
  52. request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
  53. ) -> dict[str, dict]:
  54. tools_dict = {}
  55. for tool_id in tool_ids:
  56. tool = Tools.get_tool_by_id(tool_id)
  57. if tool is None:
  58. if tool_id.startswith("server:"):
  59. server_idx = int(tool_id.split(":")[1])
  60. tool_server_connection = (
  61. request.app.state.config.TOOL_SERVER_CONNECTIONS[server_idx]
  62. )
  63. tool_server_data = None
  64. for server in request.app.state.TOOL_SERVERS:
  65. if server["idx"] == server_idx:
  66. tool_server_data = server
  67. break
  68. assert tool_server_data is not None
  69. specs = tool_server_data.get("specs", [])
  70. for spec in specs:
  71. function_name = spec["name"]
  72. auth_type = tool_server_connection.get("auth_type", "bearer")
  73. token = None
  74. if auth_type == "bearer":
  75. token = tool_server_connection.get("key", "")
  76. elif auth_type == "session":
  77. token = request.state.token.credentials
  78. def make_tool_function(function_name, token, tool_server_data):
  79. async def tool_function(**kwargs):
  80. print(
  81. f"Executing tool function {function_name} with params: {kwargs}"
  82. )
  83. return await execute_tool_server(
  84. token=token,
  85. url=tool_server_data["url"],
  86. name=function_name,
  87. params=kwargs,
  88. server_data=tool_server_data,
  89. )
  90. return tool_function
  91. tool_function = make_tool_function(
  92. function_name, token, tool_server_data
  93. )
  94. callable = get_async_tool_function_and_apply_extra_params(
  95. tool_function,
  96. {},
  97. )
  98. tool_dict = {
  99. "tool_id": tool_id,
  100. "callable": callable,
  101. "spec": spec,
  102. }
  103. # TODO: if collision, prepend toolkit name
  104. if function_name in tools_dict:
  105. log.warning(
  106. f"Tool {function_name} already exists in another tools!"
  107. )
  108. log.warning(f"Discarding {tool_id}.{function_name}")
  109. else:
  110. tools_dict[function_name] = tool_dict
  111. else:
  112. continue
  113. else:
  114. module = request.app.state.TOOLS.get(tool_id, None)
  115. if module is None:
  116. module, _ = load_tool_module_by_id(tool_id)
  117. request.app.state.TOOLS[tool_id] = module
  118. extra_params["__id__"] = tool_id
  119. # Set valves for the tool
  120. if hasattr(module, "valves") and hasattr(module, "Valves"):
  121. valves = Tools.get_tool_valves_by_id(tool_id) or {}
  122. module.valves = module.Valves(**valves)
  123. if hasattr(module, "UserValves"):
  124. extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
  125. **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
  126. )
  127. for spec in tool.specs:
  128. # TODO: Fix hack for OpenAI API
  129. # Some times breaks OpenAI but others don't. Leaving the comment
  130. for val in spec.get("parameters", {}).get("properties", {}).values():
  131. if val["type"] == "str":
  132. val["type"] = "string"
  133. # Remove internal reserved parameters (e.g. __id__, __user__)
  134. spec["parameters"]["properties"] = {
  135. key: val
  136. for key, val in spec["parameters"]["properties"].items()
  137. if not key.startswith("__")
  138. }
  139. # convert to function that takes only model params and inserts custom params
  140. function_name = spec["name"]
  141. tool_function = getattr(module, function_name)
  142. callable = get_async_tool_function_and_apply_extra_params(
  143. tool_function, extra_params
  144. )
  145. # TODO: Support Pydantic models as parameters
  146. if callable.__doc__ and callable.__doc__.strip() != "":
  147. s = re.split(":(param|return)", callable.__doc__, 1)
  148. spec["description"] = s[0]
  149. else:
  150. spec["description"] = function_name
  151. tool_dict = {
  152. "tool_id": tool_id,
  153. "callable": callable,
  154. "spec": spec,
  155. # Misc info
  156. "metadata": {
  157. "file_handler": hasattr(module, "file_handler")
  158. and module.file_handler,
  159. "citation": hasattr(module, "citation") and module.citation,
  160. },
  161. }
  162. # TODO: if collision, prepend toolkit name
  163. if function_name in tools_dict:
  164. log.warning(
  165. f"Tool {function_name} already exists in another tools!"
  166. )
  167. log.warning(f"Discarding {tool_id}.{function_name}")
  168. else:
  169. tools_dict[function_name] = tool_dict
  170. return tools_dict
  171. def parse_description(docstring: str | None) -> str:
  172. """
  173. Parse a function's docstring to extract the description.
  174. Args:
  175. docstring (str): The docstring to parse.
  176. Returns:
  177. str: The description.
  178. """
  179. if not docstring:
  180. return ""
  181. lines = [line.strip() for line in docstring.strip().split("\n")]
  182. description_lines: list[str] = []
  183. for line in lines:
  184. if re.match(r":param", line) or re.match(r":return", line):
  185. break
  186. description_lines.append(line)
  187. return "\n".join(description_lines)
  188. def parse_docstring(docstring):
  189. """
  190. Parse a function's docstring to extract parameter descriptions in reST format.
  191. Args:
  192. docstring (str): The docstring to parse.
  193. Returns:
  194. dict: A dictionary where keys are parameter names and values are descriptions.
  195. """
  196. if not docstring:
  197. return {}
  198. # Regex to match `:param name: description` format
  199. param_pattern = re.compile(r":param (\w+):\s*(.+)")
  200. param_descriptions = {}
  201. for line in docstring.splitlines():
  202. match = param_pattern.match(line.strip())
  203. if not match:
  204. continue
  205. param_name, param_description = match.groups()
  206. if param_name.startswith("__"):
  207. continue
  208. param_descriptions[param_name] = param_description
  209. return param_descriptions
  210. def convert_function_to_pydantic_model(func: Callable) -> type[BaseModel]:
  211. """
  212. Converts a Python function's type hints and docstring to a Pydantic model,
  213. including support for nested types, default values, and descriptions.
  214. Args:
  215. func: The function whose type hints and docstring should be converted.
  216. model_name: The name of the generated Pydantic model.
  217. Returns:
  218. A Pydantic model class.
  219. """
  220. type_hints = get_type_hints(func)
  221. signature = inspect.signature(func)
  222. parameters = signature.parameters
  223. docstring = func.__doc__
  224. description = parse_description(docstring)
  225. function_descriptions = parse_docstring(docstring)
  226. field_defs = {}
  227. for name, param in parameters.items():
  228. type_hint = type_hints.get(name, Any)
  229. default_value = param.default if param.default is not param.empty else ...
  230. description = function_descriptions.get(name, None)
  231. if description:
  232. field_defs[name] = type_hint, Field(default_value, description=description)
  233. else:
  234. field_defs[name] = type_hint, default_value
  235. model = create_model(func.__name__, **field_defs)
  236. model.__doc__ = description
  237. return model
  238. def get_functions_from_tool(tool: object) -> list[Callable]:
  239. return [
  240. getattr(tool, func)
  241. for func in dir(tool)
  242. if callable(
  243. getattr(tool, func)
  244. ) # checks if the attribute is callable (a method or function).
  245. and not func.startswith(
  246. "__"
  247. ) # filters out special (dunder) methods like init, str, etc. — these are usually built-in functions of an object that you might not need to use directly.
  248. and not inspect.isclass(
  249. getattr(tool, func)
  250. ) # ensures that the callable is not a class itself, just a method or function.
  251. ]
  252. def get_tool_specs(tool_module: object) -> list[dict]:
  253. function_models = map(
  254. convert_function_to_pydantic_model, get_functions_from_tool(tool_module)
  255. )
  256. specs = [
  257. convert_pydantic_model_to_openai_function_spec(function_model)
  258. for function_model in function_models
  259. ]
  260. return specs
  261. def resolve_schema(schema, components):
  262. """
  263. Recursively resolves a JSON schema using OpenAPI components.
  264. """
  265. if not schema:
  266. return {}
  267. if "$ref" in schema:
  268. ref_path = schema["$ref"]
  269. ref_parts = ref_path.strip("#/").split("/")
  270. resolved = components
  271. for part in ref_parts[1:]: # Skip the initial 'components'
  272. resolved = resolved.get(part, {})
  273. return resolve_schema(resolved, components)
  274. resolved_schema = copy.deepcopy(schema)
  275. # Recursively resolve inner schemas
  276. if "properties" in resolved_schema:
  277. for prop, prop_schema in resolved_schema["properties"].items():
  278. resolved_schema["properties"][prop] = resolve_schema(
  279. prop_schema, components
  280. )
  281. if "items" in resolved_schema:
  282. resolved_schema["items"] = resolve_schema(resolved_schema["items"], components)
  283. return resolved_schema
  284. def convert_openapi_to_tool_payload(openapi_spec):
  285. """
  286. Converts an OpenAPI specification into a custom tool payload structure.
  287. Args:
  288. openapi_spec (dict): The OpenAPI specification as a Python dict.
  289. Returns:
  290. list: A list of tool payloads.
  291. """
  292. tool_payload = []
  293. for path, methods in openapi_spec.get("paths", {}).items():
  294. for method, operation in methods.items():
  295. tool = {
  296. "type": "function",
  297. "name": operation.get("operationId"),
  298. "description": operation.get(
  299. "description", operation.get("summary", "No description available.")
  300. ),
  301. "parameters": {"type": "object", "properties": {}, "required": []},
  302. }
  303. # Extract path and query parameters
  304. for param in operation.get("parameters", []):
  305. param_name = param["name"]
  306. param_schema = param.get("schema", {})
  307. tool["parameters"]["properties"][param_name] = {
  308. "type": param_schema.get("type"),
  309. "description": param_schema.get("description", ""),
  310. }
  311. if param.get("required"):
  312. tool["parameters"]["required"].append(param_name)
  313. # Extract and resolve requestBody if available
  314. request_body = operation.get("requestBody")
  315. if request_body:
  316. content = request_body.get("content", {})
  317. json_schema = content.get("application/json", {}).get("schema")
  318. if json_schema:
  319. resolved_schema = resolve_schema(
  320. json_schema, openapi_spec.get("components", {})
  321. )
  322. if resolved_schema.get("properties"):
  323. tool["parameters"]["properties"].update(
  324. resolved_schema["properties"]
  325. )
  326. if "required" in resolved_schema:
  327. tool["parameters"]["required"] = list(
  328. set(
  329. tool["parameters"]["required"]
  330. + resolved_schema["required"]
  331. )
  332. )
  333. elif resolved_schema.get("type") == "array":
  334. tool["parameters"] = resolved_schema # special case for array
  335. tool_payload.append(tool)
  336. return tool_payload
  337. async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
  338. headers = {
  339. "Accept": "application/json",
  340. "Content-Type": "application/json",
  341. }
  342. if token:
  343. headers["Authorization"] = f"Bearer {token}"
  344. error = None
  345. try:
  346. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
  347. async with aiohttp.ClientSession(timeout=timeout) as session:
  348. async with session.get(url, headers=headers) as response:
  349. if response.status != 200:
  350. error_body = await response.json()
  351. raise Exception(error_body)
  352. # Check if URL ends with .yaml or .yml to determine format
  353. if url.lower().endswith((".yaml", ".yml")):
  354. text_content = await response.text()
  355. res = yaml.safe_load(text_content)
  356. else:
  357. res = await response.json()
  358. except Exception as err:
  359. log.exception(f"Could not fetch tool server spec from {url}")
  360. if isinstance(err, dict) and "detail" in err:
  361. error = err["detail"]
  362. else:
  363. error = str(err)
  364. raise Exception(error)
  365. data = {
  366. "openapi": res,
  367. "info": res.get("info", {}),
  368. "specs": convert_openapi_to_tool_payload(res),
  369. }
  370. print("Fetched data:", data)
  371. return data
  372. async def get_tool_servers_data(
  373. servers: List[Dict[str, Any]], session_token: Optional[str] = None
  374. ) -> List[Dict[str, Any]]:
  375. # Prepare list of enabled servers along with their original index
  376. server_entries = []
  377. for idx, server in enumerate(servers):
  378. if server.get("config", {}).get("enable"):
  379. url_path = server.get("path", "openapi.json")
  380. full_url = f"{server.get('url')}/{url_path}"
  381. auth_type = server.get("auth_type", "bearer")
  382. token = None
  383. if auth_type == "bearer":
  384. token = server.get("key", "")
  385. elif auth_type == "session":
  386. token = session_token
  387. server_entries.append((idx, server, full_url, token))
  388. # Create async tasks to fetch data
  389. tasks = [get_tool_server_data(token, url) for (_, _, url, token) in server_entries]
  390. # Execute tasks concurrently
  391. responses = await asyncio.gather(*tasks, return_exceptions=True)
  392. # Build final results with index and server metadata
  393. results = []
  394. for (idx, server, url, _), response in zip(server_entries, responses):
  395. if isinstance(response, Exception):
  396. print(f"Failed to connect to {url} OpenAPI tool server")
  397. continue
  398. results.append(
  399. {
  400. "idx": idx,
  401. "url": server.get("url"),
  402. "openapi": response.get("openapi"),
  403. "info": response.get("info"),
  404. "specs": response.get("specs"),
  405. }
  406. )
  407. return results
  408. async def execute_tool_server(
  409. token: str, url: str, name: str, params: Dict[str, Any], server_data: Dict[str, Any]
  410. ) -> Any:
  411. error = None
  412. try:
  413. openapi = server_data.get("openapi", {})
  414. paths = openapi.get("paths", {})
  415. matching_route = None
  416. for route_path, methods in paths.items():
  417. for http_method, operation in methods.items():
  418. if isinstance(operation, dict) and operation.get("operationId") == name:
  419. matching_route = (route_path, methods)
  420. break
  421. if matching_route:
  422. break
  423. if not matching_route:
  424. raise Exception(f"No matching route found for operationId: {name}")
  425. route_path, methods = matching_route
  426. method_entry = None
  427. for http_method, operation in methods.items():
  428. if operation.get("operationId") == name:
  429. method_entry = (http_method.lower(), operation)
  430. break
  431. if not method_entry:
  432. raise Exception(f"No matching method found for operationId: {name}")
  433. http_method, operation = method_entry
  434. path_params = {}
  435. query_params = {}
  436. body_params = {}
  437. for param in operation.get("parameters", []):
  438. param_name = param["name"]
  439. param_in = param["in"]
  440. if param_name in params:
  441. if param_in == "path":
  442. path_params[param_name] = params[param_name]
  443. elif param_in == "query":
  444. query_params[param_name] = params[param_name]
  445. final_url = f"{url}{route_path}"
  446. for key, value in path_params.items():
  447. final_url = final_url.replace(f"{{{key}}}", str(value))
  448. if query_params:
  449. query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
  450. final_url = f"{final_url}?{query_string}"
  451. if operation.get("requestBody", {}).get("content"):
  452. if params:
  453. body_params = params
  454. else:
  455. raise Exception(
  456. f"Request body expected for operation '{name}' but none found."
  457. )
  458. headers = {"Content-Type": "application/json"}
  459. if token:
  460. headers["Authorization"] = f"Bearer {token}"
  461. async with aiohttp.ClientSession() as session:
  462. request_method = getattr(session, http_method.lower())
  463. if http_method in ["post", "put", "patch"]:
  464. async with request_method(
  465. final_url, json=body_params, headers=headers
  466. ) as response:
  467. if response.status >= 400:
  468. text = await response.text()
  469. raise Exception(f"HTTP error {response.status}: {text}")
  470. return await response.json()
  471. else:
  472. async with request_method(final_url, headers=headers) as response:
  473. if response.status >= 400:
  474. text = await response.text()
  475. raise Exception(f"HTTP error {response.status}: {text}")
  476. return await response.json()
  477. except Exception as err:
  478. error = str(err)
  479. print("API Request Error:", error)
  480. return {"error": error}