tools.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616
  1. import logging
  2. from pathlib import Path
  3. from typing import Optional
  4. import time
  5. import re
  6. import aiohttp
  7. from open_webui.models.groups import Groups
  8. from pydantic import BaseModel, HttpUrl
  9. from fastapi import APIRouter, Depends, HTTPException, Request, status
  10. from open_webui.models.oauth_sessions import OAuthSessions
  11. from open_webui.models.tools import (
  12. ToolForm,
  13. ToolModel,
  14. ToolResponse,
  15. ToolUserResponse,
  16. Tools,
  17. )
  18. from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
  19. from open_webui.utils.tools import get_tool_specs
  20. from open_webui.utils.auth import get_admin_user, get_verified_user
  21. from open_webui.utils.access_control import has_access, has_permission
  22. from open_webui.utils.tools import get_tool_servers
  23. from open_webui.env import SRC_LOG_LEVELS
  24. from open_webui.config import CACHE_DIR, BYPASS_ADMIN_ACCESS_CONTROL
  25. from open_webui.constants import ERROR_MESSAGES
  26. log = logging.getLogger(__name__)
  27. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  28. router = APIRouter()
  29. ############################
  30. # GetTools
  31. ############################
  32. @router.get("/", response_model=list[ToolUserResponse])
  33. async def get_tools(request: Request, user=Depends(get_verified_user)):
  34. tools = [
  35. ToolUserResponse(
  36. **{
  37. **tool.model_dump(),
  38. "has_user_valves": "class UserValves(BaseModel):" in tool.content,
  39. }
  40. )
  41. for tool in Tools.get_tools()
  42. ]
  43. # OpenAPI Tool Servers
  44. for server in await get_tool_servers(request):
  45. tools.append(
  46. ToolUserResponse(
  47. **{
  48. "id": f"server:{server.get('id')}",
  49. "user_id": f"server:{server.get('id')}",
  50. "name": server.get("openapi", {})
  51. .get("info", {})
  52. .get("title", "Tool Server"),
  53. "meta": {
  54. "description": server.get("openapi", {})
  55. .get("info", {})
  56. .get("description", ""),
  57. },
  58. "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
  59. server.get("idx", 0)
  60. ]
  61. .get("config", {})
  62. .get("access_control", None),
  63. "updated_at": int(time.time()),
  64. "created_at": int(time.time()),
  65. }
  66. )
  67. )
  68. # MCP Tool Servers
  69. for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
  70. if server.get("type", "openapi") == "mcp":
  71. server_id = server.get("info", {}).get("id")
  72. auth_type = server.get("auth_type", "none")
  73. session_token = None
  74. if auth_type == "oauth_2.1":
  75. splits = server_id.split(":")
  76. server_id = splits[-1] if len(splits) > 1 else server_id
  77. session_token = (
  78. await request.app.state.oauth_client_manager.get_oauth_token(
  79. user.id, f"mcp:{server_id}"
  80. )
  81. )
  82. tools.append(
  83. ToolUserResponse(
  84. **{
  85. "id": f"server:mcp:{server.get('info', {}).get('id')}",
  86. "user_id": f"server:mcp:{server.get('info', {}).get('id')}",
  87. "name": server.get("info", {}).get("name", "MCP Tool Server"),
  88. "meta": {
  89. "description": server.get("info", {}).get(
  90. "description", ""
  91. ),
  92. },
  93. "access_control": server.get("config", {}).get(
  94. "access_control", None
  95. ),
  96. "updated_at": int(time.time()),
  97. "created_at": int(time.time()),
  98. **(
  99. {
  100. "authenticated": session_token is not None,
  101. }
  102. if auth_type == "oauth_2.1"
  103. else {}
  104. ),
  105. }
  106. )
  107. )
  108. if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
  109. # Admin can see all tools
  110. return tools
  111. else:
  112. user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
  113. tools = [
  114. tool
  115. for tool in tools
  116. if tool.user_id == user.id
  117. or has_access(user.id, "read", tool.access_control, user_group_ids)
  118. ]
  119. return tools
  120. ############################
  121. # GetToolList
  122. ############################
  123. @router.get("/list", response_model=list[ToolUserResponse])
  124. async def get_tool_list(user=Depends(get_verified_user)):
  125. if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
  126. tools = Tools.get_tools()
  127. else:
  128. tools = Tools.get_tools_by_user_id(user.id, "write")
  129. return tools
  130. ############################
  131. # LoadFunctionFromLink
  132. ############################
  133. class LoadUrlForm(BaseModel):
  134. url: HttpUrl
  135. def github_url_to_raw_url(url: str) -> str:
  136. # Handle 'tree' (folder) URLs (add main.py at the end)
  137. m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
  138. if m1:
  139. org, repo, branch, path = m1.groups()
  140. return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
  141. # Handle 'blob' (file) URLs
  142. m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
  143. if m2:
  144. org, repo, branch, path = m2.groups()
  145. return (
  146. f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
  147. )
  148. # No match; return as-is
  149. return url
  150. @router.post("/load/url", response_model=Optional[dict])
  151. async def load_tool_from_url(
  152. request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
  153. ):
  154. # NOTE: This is NOT a SSRF vulnerability:
  155. # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
  156. # and does NOT accept untrusted user input. Access is enforced by authentication.
  157. url = str(form_data.url)
  158. if not url:
  159. raise HTTPException(status_code=400, detail="Please enter a valid URL")
  160. url = github_url_to_raw_url(url)
  161. url_parts = url.rstrip("/").split("/")
  162. file_name = url_parts[-1]
  163. tool_name = (
  164. file_name[:-3]
  165. if (
  166. file_name.endswith(".py")
  167. and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
  168. )
  169. else url_parts[-2] if len(url_parts) > 1 else "function"
  170. )
  171. try:
  172. async with aiohttp.ClientSession(trust_env=True) as session:
  173. async with session.get(
  174. url, headers={"Content-Type": "application/json"}
  175. ) as resp:
  176. if resp.status != 200:
  177. raise HTTPException(
  178. status_code=resp.status, detail="Failed to fetch the tool"
  179. )
  180. data = await resp.text()
  181. if not data:
  182. raise HTTPException(
  183. status_code=400, detail="No data received from the URL"
  184. )
  185. return {
  186. "name": tool_name,
  187. "content": data,
  188. }
  189. except Exception as e:
  190. raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
  191. ############################
  192. # ExportTools
  193. ############################
  194. @router.get("/export", response_model=list[ToolModel])
  195. async def export_tools(user=Depends(get_admin_user)):
  196. tools = Tools.get_tools()
  197. return tools
  198. ############################
  199. # CreateNewTools
  200. ############################
  201. @router.post("/create", response_model=Optional[ToolResponse])
  202. async def create_new_tools(
  203. request: Request,
  204. form_data: ToolForm,
  205. user=Depends(get_verified_user),
  206. ):
  207. if user.role != "admin" and not has_permission(
  208. user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
  209. ):
  210. raise HTTPException(
  211. status_code=status.HTTP_401_UNAUTHORIZED,
  212. detail=ERROR_MESSAGES.UNAUTHORIZED,
  213. )
  214. if not form_data.id.isidentifier():
  215. raise HTTPException(
  216. status_code=status.HTTP_400_BAD_REQUEST,
  217. detail="Only alphanumeric characters and underscores are allowed in the id",
  218. )
  219. form_data.id = form_data.id.lower()
  220. tools = Tools.get_tool_by_id(form_data.id)
  221. if tools is None:
  222. try:
  223. form_data.content = replace_imports(form_data.content)
  224. tool_module, frontmatter = load_tool_module_by_id(
  225. form_data.id, content=form_data.content
  226. )
  227. form_data.meta.manifest = frontmatter
  228. TOOLS = request.app.state.TOOLS
  229. TOOLS[form_data.id] = tool_module
  230. specs = get_tool_specs(TOOLS[form_data.id])
  231. tools = Tools.insert_new_tool(user.id, form_data, specs)
  232. tool_cache_dir = CACHE_DIR / "tools" / form_data.id
  233. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  234. if tools:
  235. return tools
  236. else:
  237. raise HTTPException(
  238. status_code=status.HTTP_400_BAD_REQUEST,
  239. detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
  240. )
  241. except Exception as e:
  242. log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
  243. raise HTTPException(
  244. status_code=status.HTTP_400_BAD_REQUEST,
  245. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  246. )
  247. else:
  248. raise HTTPException(
  249. status_code=status.HTTP_400_BAD_REQUEST,
  250. detail=ERROR_MESSAGES.ID_TAKEN,
  251. )
  252. ############################
  253. # GetToolsById
  254. ############################
  255. @router.get("/id/{id}", response_model=Optional[ToolModel])
  256. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  257. tools = Tools.get_tool_by_id(id)
  258. if tools:
  259. if (
  260. user.role == "admin"
  261. or tools.user_id == user.id
  262. or has_access(user.id, "read", tools.access_control)
  263. ):
  264. return tools
  265. else:
  266. raise HTTPException(
  267. status_code=status.HTTP_401_UNAUTHORIZED,
  268. detail=ERROR_MESSAGES.NOT_FOUND,
  269. )
  270. ############################
  271. # UpdateToolsById
  272. ############################
  273. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  274. async def update_tools_by_id(
  275. request: Request,
  276. id: str,
  277. form_data: ToolForm,
  278. user=Depends(get_verified_user),
  279. ):
  280. tools = Tools.get_tool_by_id(id)
  281. if not tools:
  282. raise HTTPException(
  283. status_code=status.HTTP_401_UNAUTHORIZED,
  284. detail=ERROR_MESSAGES.NOT_FOUND,
  285. )
  286. # Is the user the original creator, in a group with write access, or an admin
  287. if (
  288. tools.user_id != user.id
  289. and not has_access(user.id, "write", tools.access_control)
  290. and user.role != "admin"
  291. ):
  292. raise HTTPException(
  293. status_code=status.HTTP_401_UNAUTHORIZED,
  294. detail=ERROR_MESSAGES.UNAUTHORIZED,
  295. )
  296. try:
  297. form_data.content = replace_imports(form_data.content)
  298. tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
  299. form_data.meta.manifest = frontmatter
  300. TOOLS = request.app.state.TOOLS
  301. TOOLS[id] = tool_module
  302. specs = get_tool_specs(TOOLS[id])
  303. updated = {
  304. **form_data.model_dump(exclude={"id"}),
  305. "specs": specs,
  306. }
  307. log.debug(updated)
  308. tools = Tools.update_tool_by_id(id, updated)
  309. if tools:
  310. return tools
  311. else:
  312. raise HTTPException(
  313. status_code=status.HTTP_400_BAD_REQUEST,
  314. detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
  315. )
  316. except Exception as e:
  317. raise HTTPException(
  318. status_code=status.HTTP_400_BAD_REQUEST,
  319. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  320. )
  321. ############################
  322. # DeleteToolsById
  323. ############################
  324. @router.delete("/id/{id}/delete", response_model=bool)
  325. async def delete_tools_by_id(
  326. request: Request, id: str, user=Depends(get_verified_user)
  327. ):
  328. tools = Tools.get_tool_by_id(id)
  329. if not tools:
  330. raise HTTPException(
  331. status_code=status.HTTP_401_UNAUTHORIZED,
  332. detail=ERROR_MESSAGES.NOT_FOUND,
  333. )
  334. if (
  335. tools.user_id != user.id
  336. and not has_access(user.id, "write", tools.access_control)
  337. and user.role != "admin"
  338. ):
  339. raise HTTPException(
  340. status_code=status.HTTP_401_UNAUTHORIZED,
  341. detail=ERROR_MESSAGES.UNAUTHORIZED,
  342. )
  343. result = Tools.delete_tool_by_id(id)
  344. if result:
  345. TOOLS = request.app.state.TOOLS
  346. if id in TOOLS:
  347. del TOOLS[id]
  348. return result
  349. ############################
  350. # GetToolValves
  351. ############################
  352. @router.get("/id/{id}/valves", response_model=Optional[dict])
  353. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  354. tools = Tools.get_tool_by_id(id)
  355. if tools:
  356. try:
  357. valves = Tools.get_tool_valves_by_id(id)
  358. return valves
  359. except Exception as e:
  360. raise HTTPException(
  361. status_code=status.HTTP_400_BAD_REQUEST,
  362. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  363. )
  364. else:
  365. raise HTTPException(
  366. status_code=status.HTTP_401_UNAUTHORIZED,
  367. detail=ERROR_MESSAGES.NOT_FOUND,
  368. )
  369. ############################
  370. # GetToolValvesSpec
  371. ############################
  372. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  373. async def get_tools_valves_spec_by_id(
  374. request: Request, id: str, user=Depends(get_verified_user)
  375. ):
  376. tools = Tools.get_tool_by_id(id)
  377. if tools:
  378. if id in request.app.state.TOOLS:
  379. tools_module = request.app.state.TOOLS[id]
  380. else:
  381. tools_module, _ = load_tool_module_by_id(id)
  382. request.app.state.TOOLS[id] = tools_module
  383. if hasattr(tools_module, "Valves"):
  384. Valves = tools_module.Valves
  385. return Valves.schema()
  386. return None
  387. else:
  388. raise HTTPException(
  389. status_code=status.HTTP_401_UNAUTHORIZED,
  390. detail=ERROR_MESSAGES.NOT_FOUND,
  391. )
  392. ############################
  393. # UpdateToolValves
  394. ############################
  395. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  396. async def update_tools_valves_by_id(
  397. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  398. ):
  399. tools = Tools.get_tool_by_id(id)
  400. if not tools:
  401. raise HTTPException(
  402. status_code=status.HTTP_401_UNAUTHORIZED,
  403. detail=ERROR_MESSAGES.NOT_FOUND,
  404. )
  405. if (
  406. tools.user_id != user.id
  407. and not has_access(user.id, "write", tools.access_control)
  408. and user.role != "admin"
  409. ):
  410. raise HTTPException(
  411. status_code=status.HTTP_400_BAD_REQUEST,
  412. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  413. )
  414. if id in request.app.state.TOOLS:
  415. tools_module = request.app.state.TOOLS[id]
  416. else:
  417. tools_module, _ = load_tool_module_by_id(id)
  418. request.app.state.TOOLS[id] = tools_module
  419. if not hasattr(tools_module, "Valves"):
  420. raise HTTPException(
  421. status_code=status.HTTP_401_UNAUTHORIZED,
  422. detail=ERROR_MESSAGES.NOT_FOUND,
  423. )
  424. Valves = tools_module.Valves
  425. try:
  426. form_data = {k: v for k, v in form_data.items() if v is not None}
  427. valves = Valves(**form_data)
  428. valves_dict = valves.model_dump(exclude_unset=True)
  429. Tools.update_tool_valves_by_id(id, valves_dict)
  430. return valves_dict
  431. except Exception as e:
  432. log.exception(f"Failed to update tool valves by id {id}: {e}")
  433. raise HTTPException(
  434. status_code=status.HTTP_400_BAD_REQUEST,
  435. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  436. )
  437. ############################
  438. # ToolUserValves
  439. ############################
  440. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  441. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  442. tools = Tools.get_tool_by_id(id)
  443. if tools:
  444. try:
  445. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  446. return user_valves
  447. except Exception as e:
  448. raise HTTPException(
  449. status_code=status.HTTP_400_BAD_REQUEST,
  450. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  451. )
  452. else:
  453. raise HTTPException(
  454. status_code=status.HTTP_401_UNAUTHORIZED,
  455. detail=ERROR_MESSAGES.NOT_FOUND,
  456. )
  457. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  458. async def get_tools_user_valves_spec_by_id(
  459. request: Request, id: str, user=Depends(get_verified_user)
  460. ):
  461. tools = Tools.get_tool_by_id(id)
  462. if tools:
  463. if id in request.app.state.TOOLS:
  464. tools_module = request.app.state.TOOLS[id]
  465. else:
  466. tools_module, _ = load_tool_module_by_id(id)
  467. request.app.state.TOOLS[id] = tools_module
  468. if hasattr(tools_module, "UserValves"):
  469. UserValves = tools_module.UserValves
  470. return UserValves.schema()
  471. return None
  472. else:
  473. raise HTTPException(
  474. status_code=status.HTTP_401_UNAUTHORIZED,
  475. detail=ERROR_MESSAGES.NOT_FOUND,
  476. )
  477. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  478. async def update_tools_user_valves_by_id(
  479. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  480. ):
  481. tools = Tools.get_tool_by_id(id)
  482. if tools:
  483. if id in request.app.state.TOOLS:
  484. tools_module = request.app.state.TOOLS[id]
  485. else:
  486. tools_module, _ = load_tool_module_by_id(id)
  487. request.app.state.TOOLS[id] = tools_module
  488. if hasattr(tools_module, "UserValves"):
  489. UserValves = tools_module.UserValves
  490. try:
  491. form_data = {k: v for k, v in form_data.items() if v is not None}
  492. user_valves = UserValves(**form_data)
  493. user_valves_dict = user_valves.model_dump(exclude_unset=True)
  494. Tools.update_user_valves_by_id_and_user_id(
  495. id, user.id, user_valves_dict
  496. )
  497. return user_valves_dict
  498. except Exception as e:
  499. log.exception(f"Failed to update user valves by id {id}: {e}")
  500. raise HTTPException(
  501. status_code=status.HTTP_400_BAD_REQUEST,
  502. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  503. )
  504. else:
  505. raise HTTPException(
  506. status_code=status.HTTP_401_UNAUTHORIZED,
  507. detail=ERROR_MESSAGES.NOT_FOUND,
  508. )
  509. else:
  510. raise HTTPException(
  511. status_code=status.HTTP_401_UNAUTHORIZED,
  512. detail=ERROR_MESSAGES.NOT_FOUND,
  513. )