tools.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632
  1. import logging
  2. from pathlib import Path
  3. from typing import Optional
  4. import time
  5. import re
  6. import aiohttp
  7. from open_webui.models.groups import Groups
  8. from pydantic import BaseModel, HttpUrl
  9. from fastapi import APIRouter, Depends, HTTPException, Request, status
  10. from open_webui.models.oauth_sessions import OAuthSessions
  11. from open_webui.models.tools import (
  12. ToolForm,
  13. ToolModel,
  14. ToolResponse,
  15. ToolUserResponse,
  16. Tools,
  17. )
  18. from open_webui.utils.plugin import (
  19. load_tool_module_by_id,
  20. replace_imports,
  21. get_tool_module_from_cache,
  22. )
  23. from open_webui.utils.tools import get_tool_specs
  24. from open_webui.utils.auth import get_admin_user, get_verified_user
  25. from open_webui.utils.access_control import has_access, has_permission
  26. from open_webui.utils.tools import get_tool_servers
  27. from open_webui.env import SRC_LOG_LEVELS
  28. from open_webui.config import CACHE_DIR, BYPASS_ADMIN_ACCESS_CONTROL
  29. from open_webui.constants import ERROR_MESSAGES
  30. log = logging.getLogger(__name__)
  31. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  32. router = APIRouter()
  33. def get_tool_module(request, tool_id, load_from_db=True):
  34. """
  35. Get the tool module by its ID.
  36. """
  37. tool_module, _ = get_tool_module_from_cache(request, tool_id, load_from_db)
  38. return tool_module
  39. ############################
  40. # GetTools
  41. ############################
  42. @router.get("/", response_model=list[ToolUserResponse])
  43. async def get_tools(request: Request, user=Depends(get_verified_user)):
  44. tools = []
  45. # Local Tools
  46. for tool in Tools.get_tools():
  47. tool_module = get_tool_module(request, tool.id)
  48. tools.append(
  49. ToolUserResponse(
  50. **{
  51. **tool.model_dump(),
  52. "has_user_valves": hasattr(tool_module, "UserValves"),
  53. }
  54. )
  55. )
  56. # OpenAPI Tool Servers
  57. for server in await get_tool_servers(request):
  58. tools.append(
  59. ToolUserResponse(
  60. **{
  61. "id": f"server:{server.get('id')}",
  62. "user_id": f"server:{server.get('id')}",
  63. "name": server.get("openapi", {})
  64. .get("info", {})
  65. .get("title", "Tool Server"),
  66. "meta": {
  67. "description": server.get("openapi", {})
  68. .get("info", {})
  69. .get("description", ""),
  70. },
  71. "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
  72. server.get("idx", 0)
  73. ]
  74. .get("config", {})
  75. .get("access_control", None),
  76. "updated_at": int(time.time()),
  77. "created_at": int(time.time()),
  78. }
  79. )
  80. )
  81. # MCP Tool Servers
  82. for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
  83. if server.get("type", "openapi") == "mcp":
  84. server_id = server.get("info", {}).get("id")
  85. auth_type = server.get("auth_type", "none")
  86. session_token = None
  87. if auth_type == "oauth_2.1":
  88. splits = server_id.split(":")
  89. server_id = splits[-1] if len(splits) > 1 else server_id
  90. session_token = (
  91. await request.app.state.oauth_client_manager.get_oauth_token(
  92. user.id, f"mcp:{server_id}"
  93. )
  94. )
  95. tools.append(
  96. ToolUserResponse(
  97. **{
  98. "id": f"server:mcp:{server.get('info', {}).get('id')}",
  99. "user_id": f"server:mcp:{server.get('info', {}).get('id')}",
  100. "name": server.get("info", {}).get("name", "MCP Tool Server"),
  101. "meta": {
  102. "description": server.get("info", {}).get(
  103. "description", ""
  104. ),
  105. },
  106. "access_control": server.get("config", {}).get(
  107. "access_control", None
  108. ),
  109. "updated_at": int(time.time()),
  110. "created_at": int(time.time()),
  111. **(
  112. {
  113. "authenticated": session_token is not None,
  114. }
  115. if auth_type == "oauth_2.1"
  116. else {}
  117. ),
  118. }
  119. )
  120. )
  121. if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
  122. # Admin can see all tools
  123. return tools
  124. else:
  125. user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
  126. tools = [
  127. tool
  128. for tool in tools
  129. if tool.user_id == user.id
  130. or has_access(user.id, "read", tool.access_control, user_group_ids)
  131. ]
  132. return tools
  133. ############################
  134. # GetToolList
  135. ############################
  136. @router.get("/list", response_model=list[ToolUserResponse])
  137. async def get_tool_list(user=Depends(get_verified_user)):
  138. if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
  139. tools = Tools.get_tools()
  140. else:
  141. tools = Tools.get_tools_by_user_id(user.id, "write")
  142. return tools
  143. ############################
  144. # LoadFunctionFromLink
  145. ############################
  146. class LoadUrlForm(BaseModel):
  147. url: HttpUrl
  148. def github_url_to_raw_url(url: str) -> str:
  149. # Handle 'tree' (folder) URLs (add main.py at the end)
  150. m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
  151. if m1:
  152. org, repo, branch, path = m1.groups()
  153. return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
  154. # Handle 'blob' (file) URLs
  155. m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
  156. if m2:
  157. org, repo, branch, path = m2.groups()
  158. return (
  159. f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
  160. )
  161. # No match; return as-is
  162. return url
  163. @router.post("/load/url", response_model=Optional[dict])
  164. async def load_tool_from_url(
  165. request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
  166. ):
  167. # NOTE: This is NOT a SSRF vulnerability:
  168. # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
  169. # and does NOT accept untrusted user input. Access is enforced by authentication.
  170. url = str(form_data.url)
  171. if not url:
  172. raise HTTPException(status_code=400, detail="Please enter a valid URL")
  173. url = github_url_to_raw_url(url)
  174. url_parts = url.rstrip("/").split("/")
  175. file_name = url_parts[-1]
  176. tool_name = (
  177. file_name[:-3]
  178. if (
  179. file_name.endswith(".py")
  180. and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
  181. )
  182. else url_parts[-2] if len(url_parts) > 1 else "function"
  183. )
  184. try:
  185. async with aiohttp.ClientSession(trust_env=True) as session:
  186. async with session.get(
  187. url, headers={"Content-Type": "application/json"}
  188. ) as resp:
  189. if resp.status != 200:
  190. raise HTTPException(
  191. status_code=resp.status, detail="Failed to fetch the tool"
  192. )
  193. data = await resp.text()
  194. if not data:
  195. raise HTTPException(
  196. status_code=400, detail="No data received from the URL"
  197. )
  198. return {
  199. "name": tool_name,
  200. "content": data,
  201. }
  202. except Exception as e:
  203. raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
  204. ############################
  205. # ExportTools
  206. ############################
  207. @router.get("/export", response_model=list[ToolModel])
  208. async def export_tools(user=Depends(get_admin_user)):
  209. tools = Tools.get_tools()
  210. return tools
  211. ############################
  212. # CreateNewTools
  213. ############################
  214. @router.post("/create", response_model=Optional[ToolResponse])
  215. async def create_new_tools(
  216. request: Request,
  217. form_data: ToolForm,
  218. user=Depends(get_verified_user),
  219. ):
  220. if user.role != "admin" and not has_permission(
  221. user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
  222. ):
  223. raise HTTPException(
  224. status_code=status.HTTP_401_UNAUTHORIZED,
  225. detail=ERROR_MESSAGES.UNAUTHORIZED,
  226. )
  227. if not form_data.id.isidentifier():
  228. raise HTTPException(
  229. status_code=status.HTTP_400_BAD_REQUEST,
  230. detail="Only alphanumeric characters and underscores are allowed in the id",
  231. )
  232. form_data.id = form_data.id.lower()
  233. tools = Tools.get_tool_by_id(form_data.id)
  234. if tools is None:
  235. try:
  236. form_data.content = replace_imports(form_data.content)
  237. tool_module, frontmatter = load_tool_module_by_id(
  238. form_data.id, content=form_data.content
  239. )
  240. form_data.meta.manifest = frontmatter
  241. TOOLS = request.app.state.TOOLS
  242. TOOLS[form_data.id] = tool_module
  243. specs = get_tool_specs(TOOLS[form_data.id])
  244. tools = Tools.insert_new_tool(user.id, form_data, specs)
  245. tool_cache_dir = CACHE_DIR / "tools" / form_data.id
  246. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  247. if tools:
  248. return tools
  249. else:
  250. raise HTTPException(
  251. status_code=status.HTTP_400_BAD_REQUEST,
  252. detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
  253. )
  254. except Exception as e:
  255. log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
  256. raise HTTPException(
  257. status_code=status.HTTP_400_BAD_REQUEST,
  258. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  259. )
  260. else:
  261. raise HTTPException(
  262. status_code=status.HTTP_400_BAD_REQUEST,
  263. detail=ERROR_MESSAGES.ID_TAKEN,
  264. )
  265. ############################
  266. # GetToolsById
  267. ############################
  268. @router.get("/id/{id}", response_model=Optional[ToolModel])
  269. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  270. tools = Tools.get_tool_by_id(id)
  271. if tools:
  272. if (
  273. user.role == "admin"
  274. or tools.user_id == user.id
  275. or has_access(user.id, "read", tools.access_control)
  276. ):
  277. return tools
  278. else:
  279. raise HTTPException(
  280. status_code=status.HTTP_401_UNAUTHORIZED,
  281. detail=ERROR_MESSAGES.NOT_FOUND,
  282. )
  283. ############################
  284. # UpdateToolsById
  285. ############################
  286. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  287. async def update_tools_by_id(
  288. request: Request,
  289. id: str,
  290. form_data: ToolForm,
  291. user=Depends(get_verified_user),
  292. ):
  293. tools = Tools.get_tool_by_id(id)
  294. if not tools:
  295. raise HTTPException(
  296. status_code=status.HTTP_401_UNAUTHORIZED,
  297. detail=ERROR_MESSAGES.NOT_FOUND,
  298. )
  299. # Is the user the original creator, in a group with write access, or an admin
  300. if (
  301. tools.user_id != user.id
  302. and not has_access(user.id, "write", tools.access_control)
  303. and user.role != "admin"
  304. ):
  305. raise HTTPException(
  306. status_code=status.HTTP_401_UNAUTHORIZED,
  307. detail=ERROR_MESSAGES.UNAUTHORIZED,
  308. )
  309. try:
  310. form_data.content = replace_imports(form_data.content)
  311. tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
  312. form_data.meta.manifest = frontmatter
  313. TOOLS = request.app.state.TOOLS
  314. TOOLS[id] = tool_module
  315. specs = get_tool_specs(TOOLS[id])
  316. updated = {
  317. **form_data.model_dump(exclude={"id"}),
  318. "specs": specs,
  319. }
  320. log.debug(updated)
  321. tools = Tools.update_tool_by_id(id, updated)
  322. if tools:
  323. return tools
  324. else:
  325. raise HTTPException(
  326. status_code=status.HTTP_400_BAD_REQUEST,
  327. detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
  328. )
  329. except Exception as e:
  330. raise HTTPException(
  331. status_code=status.HTTP_400_BAD_REQUEST,
  332. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  333. )
  334. ############################
  335. # DeleteToolsById
  336. ############################
  337. @router.delete("/id/{id}/delete", response_model=bool)
  338. async def delete_tools_by_id(
  339. request: Request, id: str, user=Depends(get_verified_user)
  340. ):
  341. tools = Tools.get_tool_by_id(id)
  342. if not tools:
  343. raise HTTPException(
  344. status_code=status.HTTP_401_UNAUTHORIZED,
  345. detail=ERROR_MESSAGES.NOT_FOUND,
  346. )
  347. if (
  348. tools.user_id != user.id
  349. and not has_access(user.id, "write", tools.access_control)
  350. and user.role != "admin"
  351. ):
  352. raise HTTPException(
  353. status_code=status.HTTP_401_UNAUTHORIZED,
  354. detail=ERROR_MESSAGES.UNAUTHORIZED,
  355. )
  356. result = Tools.delete_tool_by_id(id)
  357. if result:
  358. TOOLS = request.app.state.TOOLS
  359. if id in TOOLS:
  360. del TOOLS[id]
  361. return result
  362. ############################
  363. # GetToolValves
  364. ############################
  365. @router.get("/id/{id}/valves", response_model=Optional[dict])
  366. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  367. tools = Tools.get_tool_by_id(id)
  368. if tools:
  369. try:
  370. valves = Tools.get_tool_valves_by_id(id)
  371. return valves
  372. except Exception as e:
  373. raise HTTPException(
  374. status_code=status.HTTP_400_BAD_REQUEST,
  375. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  376. )
  377. else:
  378. raise HTTPException(
  379. status_code=status.HTTP_401_UNAUTHORIZED,
  380. detail=ERROR_MESSAGES.NOT_FOUND,
  381. )
  382. ############################
  383. # GetToolValvesSpec
  384. ############################
  385. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  386. async def get_tools_valves_spec_by_id(
  387. request: Request, id: str, user=Depends(get_verified_user)
  388. ):
  389. tools = Tools.get_tool_by_id(id)
  390. if tools:
  391. if id in request.app.state.TOOLS:
  392. tools_module = request.app.state.TOOLS[id]
  393. else:
  394. tools_module, _ = load_tool_module_by_id(id)
  395. request.app.state.TOOLS[id] = tools_module
  396. if hasattr(tools_module, "Valves"):
  397. Valves = tools_module.Valves
  398. return Valves.schema()
  399. return None
  400. else:
  401. raise HTTPException(
  402. status_code=status.HTTP_401_UNAUTHORIZED,
  403. detail=ERROR_MESSAGES.NOT_FOUND,
  404. )
  405. ############################
  406. # UpdateToolValves
  407. ############################
  408. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  409. async def update_tools_valves_by_id(
  410. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  411. ):
  412. tools = Tools.get_tool_by_id(id)
  413. if not tools:
  414. raise HTTPException(
  415. status_code=status.HTTP_401_UNAUTHORIZED,
  416. detail=ERROR_MESSAGES.NOT_FOUND,
  417. )
  418. if (
  419. tools.user_id != user.id
  420. and not has_access(user.id, "write", tools.access_control)
  421. and user.role != "admin"
  422. ):
  423. raise HTTPException(
  424. status_code=status.HTTP_400_BAD_REQUEST,
  425. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  426. )
  427. if id in request.app.state.TOOLS:
  428. tools_module = request.app.state.TOOLS[id]
  429. else:
  430. tools_module, _ = load_tool_module_by_id(id)
  431. request.app.state.TOOLS[id] = tools_module
  432. if not hasattr(tools_module, "Valves"):
  433. raise HTTPException(
  434. status_code=status.HTTP_401_UNAUTHORIZED,
  435. detail=ERROR_MESSAGES.NOT_FOUND,
  436. )
  437. Valves = tools_module.Valves
  438. try:
  439. form_data = {k: v for k, v in form_data.items() if v is not None}
  440. valves = Valves(**form_data)
  441. valves_dict = valves.model_dump(exclude_unset=True)
  442. Tools.update_tool_valves_by_id(id, valves_dict)
  443. return valves_dict
  444. except Exception as e:
  445. log.exception(f"Failed to update tool valves by id {id}: {e}")
  446. raise HTTPException(
  447. status_code=status.HTTP_400_BAD_REQUEST,
  448. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  449. )
  450. ############################
  451. # ToolUserValves
  452. ############################
  453. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  454. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  455. tools = Tools.get_tool_by_id(id)
  456. if tools:
  457. try:
  458. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  459. return user_valves
  460. except Exception as e:
  461. raise HTTPException(
  462. status_code=status.HTTP_400_BAD_REQUEST,
  463. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  464. )
  465. else:
  466. raise HTTPException(
  467. status_code=status.HTTP_401_UNAUTHORIZED,
  468. detail=ERROR_MESSAGES.NOT_FOUND,
  469. )
  470. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  471. async def get_tools_user_valves_spec_by_id(
  472. request: Request, id: str, user=Depends(get_verified_user)
  473. ):
  474. tools = Tools.get_tool_by_id(id)
  475. if tools:
  476. if id in request.app.state.TOOLS:
  477. tools_module = request.app.state.TOOLS[id]
  478. else:
  479. tools_module, _ = load_tool_module_by_id(id)
  480. request.app.state.TOOLS[id] = tools_module
  481. if hasattr(tools_module, "UserValves"):
  482. UserValves = tools_module.UserValves
  483. return UserValves.schema()
  484. return None
  485. else:
  486. raise HTTPException(
  487. status_code=status.HTTP_401_UNAUTHORIZED,
  488. detail=ERROR_MESSAGES.NOT_FOUND,
  489. )
  490. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  491. async def update_tools_user_valves_by_id(
  492. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  493. ):
  494. tools = Tools.get_tool_by_id(id)
  495. if tools:
  496. if id in request.app.state.TOOLS:
  497. tools_module = request.app.state.TOOLS[id]
  498. else:
  499. tools_module, _ = load_tool_module_by_id(id)
  500. request.app.state.TOOLS[id] = tools_module
  501. if hasattr(tools_module, "UserValves"):
  502. UserValves = tools_module.UserValves
  503. try:
  504. form_data = {k: v for k, v in form_data.items() if v is not None}
  505. user_valves = UserValves(**form_data)
  506. user_valves_dict = user_valves.model_dump(exclude_unset=True)
  507. Tools.update_user_valves_by_id_and_user_id(
  508. id, user.id, user_valves_dict
  509. )
  510. return user_valves_dict
  511. except Exception as e:
  512. log.exception(f"Failed to update user valves by id {id}: {e}")
  513. raise HTTPException(
  514. status_code=status.HTTP_400_BAD_REQUEST,
  515. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  516. )
  517. else:
  518. raise HTTPException(
  519. status_code=status.HTTP_401_UNAUTHORIZED,
  520. detail=ERROR_MESSAGES.NOT_FOUND,
  521. )
  522. else:
  523. raise HTTPException(
  524. status_code=status.HTTP_401_UNAUTHORIZED,
  525. detail=ERROR_MESSAGES.NOT_FOUND,
  526. )