tools.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. import logging
  2. from pathlib import Path
  3. from typing import Optional
  4. import time
  5. import re
  6. import aiohttp
  7. from open_webui.models.groups import Groups
  8. from pydantic import BaseModel, HttpUrl
  9. from fastapi import APIRouter, Depends, HTTPException, Request, status
  10. from open_webui.models.oauth_sessions import OAuthSessions
  11. from open_webui.models.tools import (
  12. ToolForm,
  13. ToolModel,
  14. ToolResponse,
  15. ToolUserResponse,
  16. Tools,
  17. )
  18. from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
  19. from open_webui.utils.tools import get_tool_specs
  20. from open_webui.utils.auth import get_admin_user, get_verified_user
  21. from open_webui.utils.access_control import has_access, has_permission
  22. from open_webui.utils.tools import get_tool_servers
  23. from open_webui.env import SRC_LOG_LEVELS
  24. from open_webui.config import CACHE_DIR, BYPASS_ADMIN_ACCESS_CONTROL
  25. from open_webui.constants import ERROR_MESSAGES
  26. log = logging.getLogger(__name__)
  27. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  28. router = APIRouter()
  29. ############################
  30. # GetTools
  31. ############################
  32. @router.get("/", response_model=list[ToolUserResponse])
  33. async def get_tools(request: Request, user=Depends(get_verified_user)):
  34. tools = [
  35. ToolUserResponse(
  36. **{
  37. **tool.model_dump(),
  38. "has_user_valves": "class UserValves(BaseModel):" in tool.content,
  39. }
  40. )
  41. for tool in Tools.get_tools()
  42. ]
  43. # OpenAPI Tool Servers
  44. for server in await get_tool_servers(request):
  45. tools.append(
  46. ToolUserResponse(
  47. **{
  48. "id": f"server:{server.get('id')}",
  49. "user_id": f"server:{server.get('id')}",
  50. "name": server.get("openapi", {})
  51. .get("info", {})
  52. .get("title", "Tool Server"),
  53. "meta": {
  54. "description": server.get("openapi", {})
  55. .get("info", {})
  56. .get("description", ""),
  57. },
  58. "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
  59. server.get("idx", 0)
  60. ]
  61. .get("config", {})
  62. .get("access_control", None),
  63. "updated_at": int(time.time()),
  64. "created_at": int(time.time()),
  65. }
  66. )
  67. )
  68. # MCP Tool Servers
  69. for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
  70. if server.get("type", "openapi") == "mcp":
  71. server_id = server.get("info", {}).get("id")
  72. auth_type = server.get("auth_type", "none")
  73. session_token = None
  74. if auth_type == "oauth_2.1":
  75. splits = server_id.split(":")
  76. server_id = splits[-1] if len(splits) > 1 else server_id
  77. session_token = (
  78. await request.app.state.oauth_client_manager.get_oauth_token(
  79. user.id, f"mcp:{server_id}"
  80. )
  81. )
  82. print("User ID:", user.id)
  83. print("Server ID:", server_id)
  84. print("MCP Session Token:", session_token)
  85. tools.append(
  86. ToolUserResponse(
  87. **{
  88. "id": f"server:mcp:{server.get('info', {}).get('id')}",
  89. "user_id": f"server:mcp:{server.get('info', {}).get('id')}",
  90. "name": server.get("info", {}).get("name", "MCP Tool Server"),
  91. "meta": {
  92. "description": server.get("info", {}).get(
  93. "description", ""
  94. ),
  95. },
  96. "access_control": server.get("config", {}).get(
  97. "access_control", None
  98. ),
  99. "updated_at": int(time.time()),
  100. "created_at": int(time.time()),
  101. **(
  102. {
  103. "authenticated": session_token is not None,
  104. }
  105. if auth_type == "oauth_2.1"
  106. else {}
  107. ),
  108. }
  109. )
  110. )
  111. if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
  112. # Admin can see all tools
  113. return tools
  114. else:
  115. user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
  116. tools = [
  117. tool
  118. for tool in tools
  119. if tool.user_id == user.id
  120. or has_access(user.id, "read", tool.access_control, user_group_ids)
  121. ]
  122. return tools
  123. ############################
  124. # GetToolList
  125. ############################
  126. @router.get("/list", response_model=list[ToolUserResponse])
  127. async def get_tool_list(user=Depends(get_verified_user)):
  128. if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
  129. tools = Tools.get_tools()
  130. else:
  131. tools = Tools.get_tools_by_user_id(user.id, "write")
  132. return tools
  133. ############################
  134. # LoadFunctionFromLink
  135. ############################
  136. class LoadUrlForm(BaseModel):
  137. url: HttpUrl
  138. def github_url_to_raw_url(url: str) -> str:
  139. # Handle 'tree' (folder) URLs (add main.py at the end)
  140. m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
  141. if m1:
  142. org, repo, branch, path = m1.groups()
  143. return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
  144. # Handle 'blob' (file) URLs
  145. m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
  146. if m2:
  147. org, repo, branch, path = m2.groups()
  148. return (
  149. f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
  150. )
  151. # No match; return as-is
  152. return url
  153. @router.post("/load/url", response_model=Optional[dict])
  154. async def load_tool_from_url(
  155. request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
  156. ):
  157. # NOTE: This is NOT a SSRF vulnerability:
  158. # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
  159. # and does NOT accept untrusted user input. Access is enforced by authentication.
  160. url = str(form_data.url)
  161. if not url:
  162. raise HTTPException(status_code=400, detail="Please enter a valid URL")
  163. url = github_url_to_raw_url(url)
  164. url_parts = url.rstrip("/").split("/")
  165. file_name = url_parts[-1]
  166. tool_name = (
  167. file_name[:-3]
  168. if (
  169. file_name.endswith(".py")
  170. and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
  171. )
  172. else url_parts[-2] if len(url_parts) > 1 else "function"
  173. )
  174. try:
  175. async with aiohttp.ClientSession(trust_env=True) as session:
  176. async with session.get(
  177. url, headers={"Content-Type": "application/json"}
  178. ) as resp:
  179. if resp.status != 200:
  180. raise HTTPException(
  181. status_code=resp.status, detail="Failed to fetch the tool"
  182. )
  183. data = await resp.text()
  184. if not data:
  185. raise HTTPException(
  186. status_code=400, detail="No data received from the URL"
  187. )
  188. return {
  189. "name": tool_name,
  190. "content": data,
  191. }
  192. except Exception as e:
  193. raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
  194. ############################
  195. # ExportTools
  196. ############################
  197. @router.get("/export", response_model=list[ToolModel])
  198. async def export_tools(user=Depends(get_admin_user)):
  199. tools = Tools.get_tools()
  200. return tools
  201. ############################
  202. # CreateNewTools
  203. ############################
  204. @router.post("/create", response_model=Optional[ToolResponse])
  205. async def create_new_tools(
  206. request: Request,
  207. form_data: ToolForm,
  208. user=Depends(get_verified_user),
  209. ):
  210. if user.role != "admin" and not has_permission(
  211. user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
  212. ):
  213. raise HTTPException(
  214. status_code=status.HTTP_401_UNAUTHORIZED,
  215. detail=ERROR_MESSAGES.UNAUTHORIZED,
  216. )
  217. if not form_data.id.isidentifier():
  218. raise HTTPException(
  219. status_code=status.HTTP_400_BAD_REQUEST,
  220. detail="Only alphanumeric characters and underscores are allowed in the id",
  221. )
  222. form_data.id = form_data.id.lower()
  223. tools = Tools.get_tool_by_id(form_data.id)
  224. if tools is None:
  225. try:
  226. form_data.content = replace_imports(form_data.content)
  227. tool_module, frontmatter = load_tool_module_by_id(
  228. form_data.id, content=form_data.content
  229. )
  230. form_data.meta.manifest = frontmatter
  231. TOOLS = request.app.state.TOOLS
  232. TOOLS[form_data.id] = tool_module
  233. specs = get_tool_specs(TOOLS[form_data.id])
  234. tools = Tools.insert_new_tool(user.id, form_data, specs)
  235. tool_cache_dir = CACHE_DIR / "tools" / form_data.id
  236. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  237. if tools:
  238. return tools
  239. else:
  240. raise HTTPException(
  241. status_code=status.HTTP_400_BAD_REQUEST,
  242. detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
  243. )
  244. except Exception as e:
  245. log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
  246. raise HTTPException(
  247. status_code=status.HTTP_400_BAD_REQUEST,
  248. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  249. )
  250. else:
  251. raise HTTPException(
  252. status_code=status.HTTP_400_BAD_REQUEST,
  253. detail=ERROR_MESSAGES.ID_TAKEN,
  254. )
  255. ############################
  256. # GetToolsById
  257. ############################
  258. @router.get("/id/{id}", response_model=Optional[ToolModel])
  259. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  260. tools = Tools.get_tool_by_id(id)
  261. if tools:
  262. if (
  263. user.role == "admin"
  264. or tools.user_id == user.id
  265. or has_access(user.id, "read", tools.access_control)
  266. ):
  267. return tools
  268. else:
  269. raise HTTPException(
  270. status_code=status.HTTP_401_UNAUTHORIZED,
  271. detail=ERROR_MESSAGES.NOT_FOUND,
  272. )
  273. ############################
  274. # UpdateToolsById
  275. ############################
  276. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  277. async def update_tools_by_id(
  278. request: Request,
  279. id: str,
  280. form_data: ToolForm,
  281. user=Depends(get_verified_user),
  282. ):
  283. tools = Tools.get_tool_by_id(id)
  284. if not tools:
  285. raise HTTPException(
  286. status_code=status.HTTP_401_UNAUTHORIZED,
  287. detail=ERROR_MESSAGES.NOT_FOUND,
  288. )
  289. # Is the user the original creator, in a group with write access, or an admin
  290. if (
  291. tools.user_id != user.id
  292. and not has_access(user.id, "write", tools.access_control)
  293. and user.role != "admin"
  294. ):
  295. raise HTTPException(
  296. status_code=status.HTTP_401_UNAUTHORIZED,
  297. detail=ERROR_MESSAGES.UNAUTHORIZED,
  298. )
  299. try:
  300. form_data.content = replace_imports(form_data.content)
  301. tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
  302. form_data.meta.manifest = frontmatter
  303. TOOLS = request.app.state.TOOLS
  304. TOOLS[id] = tool_module
  305. specs = get_tool_specs(TOOLS[id])
  306. updated = {
  307. **form_data.model_dump(exclude={"id"}),
  308. "specs": specs,
  309. }
  310. log.debug(updated)
  311. tools = Tools.update_tool_by_id(id, updated)
  312. if tools:
  313. return tools
  314. else:
  315. raise HTTPException(
  316. status_code=status.HTTP_400_BAD_REQUEST,
  317. detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
  318. )
  319. except Exception as e:
  320. raise HTTPException(
  321. status_code=status.HTTP_400_BAD_REQUEST,
  322. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  323. )
  324. ############################
  325. # DeleteToolsById
  326. ############################
  327. @router.delete("/id/{id}/delete", response_model=bool)
  328. async def delete_tools_by_id(
  329. request: Request, id: str, user=Depends(get_verified_user)
  330. ):
  331. tools = Tools.get_tool_by_id(id)
  332. if not tools:
  333. raise HTTPException(
  334. status_code=status.HTTP_401_UNAUTHORIZED,
  335. detail=ERROR_MESSAGES.NOT_FOUND,
  336. )
  337. if (
  338. tools.user_id != user.id
  339. and not has_access(user.id, "write", tools.access_control)
  340. and user.role != "admin"
  341. ):
  342. raise HTTPException(
  343. status_code=status.HTTP_401_UNAUTHORIZED,
  344. detail=ERROR_MESSAGES.UNAUTHORIZED,
  345. )
  346. result = Tools.delete_tool_by_id(id)
  347. if result:
  348. TOOLS = request.app.state.TOOLS
  349. if id in TOOLS:
  350. del TOOLS[id]
  351. return result
  352. ############################
  353. # GetToolValves
  354. ############################
  355. @router.get("/id/{id}/valves", response_model=Optional[dict])
  356. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  357. tools = Tools.get_tool_by_id(id)
  358. if tools:
  359. try:
  360. valves = Tools.get_tool_valves_by_id(id)
  361. return valves
  362. except Exception as e:
  363. raise HTTPException(
  364. status_code=status.HTTP_400_BAD_REQUEST,
  365. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  366. )
  367. else:
  368. raise HTTPException(
  369. status_code=status.HTTP_401_UNAUTHORIZED,
  370. detail=ERROR_MESSAGES.NOT_FOUND,
  371. )
  372. ############################
  373. # GetToolValvesSpec
  374. ############################
  375. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  376. async def get_tools_valves_spec_by_id(
  377. request: Request, id: str, user=Depends(get_verified_user)
  378. ):
  379. tools = Tools.get_tool_by_id(id)
  380. if tools:
  381. if id in request.app.state.TOOLS:
  382. tools_module = request.app.state.TOOLS[id]
  383. else:
  384. tools_module, _ = load_tool_module_by_id(id)
  385. request.app.state.TOOLS[id] = tools_module
  386. if hasattr(tools_module, "Valves"):
  387. Valves = tools_module.Valves
  388. return Valves.schema()
  389. return None
  390. else:
  391. raise HTTPException(
  392. status_code=status.HTTP_401_UNAUTHORIZED,
  393. detail=ERROR_MESSAGES.NOT_FOUND,
  394. )
  395. ############################
  396. # UpdateToolValves
  397. ############################
  398. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  399. async def update_tools_valves_by_id(
  400. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  401. ):
  402. tools = Tools.get_tool_by_id(id)
  403. if not tools:
  404. raise HTTPException(
  405. status_code=status.HTTP_401_UNAUTHORIZED,
  406. detail=ERROR_MESSAGES.NOT_FOUND,
  407. )
  408. if (
  409. tools.user_id != user.id
  410. and not has_access(user.id, "write", tools.access_control)
  411. and user.role != "admin"
  412. ):
  413. raise HTTPException(
  414. status_code=status.HTTP_400_BAD_REQUEST,
  415. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  416. )
  417. if id in request.app.state.TOOLS:
  418. tools_module = request.app.state.TOOLS[id]
  419. else:
  420. tools_module, _ = load_tool_module_by_id(id)
  421. request.app.state.TOOLS[id] = tools_module
  422. if not hasattr(tools_module, "Valves"):
  423. raise HTTPException(
  424. status_code=status.HTTP_401_UNAUTHORIZED,
  425. detail=ERROR_MESSAGES.NOT_FOUND,
  426. )
  427. Valves = tools_module.Valves
  428. try:
  429. form_data = {k: v for k, v in form_data.items() if v is not None}
  430. valves = Valves(**form_data)
  431. valves_dict = valves.model_dump(exclude_unset=True)
  432. Tools.update_tool_valves_by_id(id, valves_dict)
  433. return valves_dict
  434. except Exception as e:
  435. log.exception(f"Failed to update tool valves by id {id}: {e}")
  436. raise HTTPException(
  437. status_code=status.HTTP_400_BAD_REQUEST,
  438. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  439. )
  440. ############################
  441. # ToolUserValves
  442. ############################
  443. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  444. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  445. tools = Tools.get_tool_by_id(id)
  446. if tools:
  447. try:
  448. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  449. return user_valves
  450. except Exception as e:
  451. raise HTTPException(
  452. status_code=status.HTTP_400_BAD_REQUEST,
  453. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  454. )
  455. else:
  456. raise HTTPException(
  457. status_code=status.HTTP_401_UNAUTHORIZED,
  458. detail=ERROR_MESSAGES.NOT_FOUND,
  459. )
  460. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  461. async def get_tools_user_valves_spec_by_id(
  462. request: Request, id: str, user=Depends(get_verified_user)
  463. ):
  464. tools = Tools.get_tool_by_id(id)
  465. if tools:
  466. if id in request.app.state.TOOLS:
  467. tools_module = request.app.state.TOOLS[id]
  468. else:
  469. tools_module, _ = load_tool_module_by_id(id)
  470. request.app.state.TOOLS[id] = tools_module
  471. if hasattr(tools_module, "UserValves"):
  472. UserValves = tools_module.UserValves
  473. return UserValves.schema()
  474. return None
  475. else:
  476. raise HTTPException(
  477. status_code=status.HTTP_401_UNAUTHORIZED,
  478. detail=ERROR_MESSAGES.NOT_FOUND,
  479. )
  480. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  481. async def update_tools_user_valves_by_id(
  482. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  483. ):
  484. tools = Tools.get_tool_by_id(id)
  485. if tools:
  486. if id in request.app.state.TOOLS:
  487. tools_module = request.app.state.TOOLS[id]
  488. else:
  489. tools_module, _ = load_tool_module_by_id(id)
  490. request.app.state.TOOLS[id] = tools_module
  491. if hasattr(tools_module, "UserValves"):
  492. UserValves = tools_module.UserValves
  493. try:
  494. form_data = {k: v for k, v in form_data.items() if v is not None}
  495. user_valves = UserValves(**form_data)
  496. user_valves_dict = user_valves.model_dump(exclude_unset=True)
  497. Tools.update_user_valves_by_id_and_user_id(
  498. id, user.id, user_valves_dict
  499. )
  500. return user_valves_dict
  501. except Exception as e:
  502. log.exception(f"Failed to update user valves by id {id}: {e}")
  503. raise HTTPException(
  504. status_code=status.HTTP_400_BAD_REQUEST,
  505. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  506. )
  507. else:
  508. raise HTTPException(
  509. status_code=status.HTTP_401_UNAUTHORIZED,
  510. detail=ERROR_MESSAGES.NOT_FOUND,
  511. )
  512. else:
  513. raise HTTPException(
  514. status_code=status.HTTP_401_UNAUTHORIZED,
  515. detail=ERROR_MESSAGES.NOT_FOUND,
  516. )