1
0

tools.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. import logging
  2. from pathlib import Path
  3. from typing import Optional
  4. import time
  5. import re
  6. import aiohttp
  7. from open_webui.models.groups import Groups
  8. from pydantic import BaseModel, HttpUrl
  9. from fastapi import APIRouter, Depends, HTTPException, Request, status
  10. from open_webui.models.tools import (
  11. ToolForm,
  12. ToolModel,
  13. ToolResponse,
  14. ToolUserResponse,
  15. Tools,
  16. )
  17. from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
  18. from open_webui.utils.tools import get_tool_specs
  19. from open_webui.utils.auth import get_admin_user, get_verified_user
  20. from open_webui.utils.access_control import has_access, has_permission
  21. from open_webui.utils.tools import get_tool_servers
  22. from open_webui.env import SRC_LOG_LEVELS
  23. from open_webui.config import CACHE_DIR, BYPASS_ADMIN_ACCESS_CONTROL
  24. from open_webui.constants import ERROR_MESSAGES
  25. log = logging.getLogger(__name__)
  26. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  27. router = APIRouter()
  28. ############################
  29. # GetTools
  30. ############################
  31. @router.get("/", response_model=list[ToolUserResponse])
  32. async def get_tools(request: Request, user=Depends(get_verified_user)):
  33. tools = Tools.get_tools()
  34. # OpenAPI Tool Servers
  35. for server in await get_tool_servers(request):
  36. tools.append(
  37. ToolUserResponse(
  38. **{
  39. "id": f"server:{server.get('id')}",
  40. "user_id": f"server:{server.get('id')}",
  41. "name": server.get("openapi", {})
  42. .get("info", {})
  43. .get("title", "Tool Server"),
  44. "meta": {
  45. "description": server.get("openapi", {})
  46. .get("info", {})
  47. .get("description", ""),
  48. },
  49. "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
  50. server.get("idx", 0)
  51. ]
  52. .get("config", {})
  53. .get("access_control", None),
  54. "updated_at": int(time.time()),
  55. "created_at": int(time.time()),
  56. }
  57. )
  58. )
  59. # MCP Tool Servers
  60. for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
  61. if server.get("type", "openapi") == "mcp":
  62. tools.append(
  63. ToolUserResponse(
  64. **{
  65. "id": f"server:mcp:{server.get('info', {}).get('id')}",
  66. "user_id": f"server:mcp:{server.get('info', {}).get('id')}",
  67. "name": server.get("info", {}).get("name", "MCP Tool Server"),
  68. "meta": {
  69. "description": server.get("info", {}).get(
  70. "description", ""
  71. ),
  72. },
  73. "access_control": server.get("config", {}).get(
  74. "access_control", None
  75. ),
  76. "updated_at": int(time.time()),
  77. "created_at": int(time.time()),
  78. }
  79. )
  80. )
  81. if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
  82. # Admin can see all tools
  83. return tools
  84. else:
  85. user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
  86. tools = [
  87. tool
  88. for tool in tools
  89. if tool.user_id == user.id
  90. or has_access(user.id, "read", tool.access_control, user_group_ids)
  91. ]
  92. return tools
  93. ############################
  94. # GetToolList
  95. ############################
  96. @router.get("/list", response_model=list[ToolUserResponse])
  97. async def get_tool_list(user=Depends(get_verified_user)):
  98. if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
  99. tools = Tools.get_tools()
  100. else:
  101. tools = Tools.get_tools_by_user_id(user.id, "write")
  102. return tools
  103. ############################
  104. # LoadFunctionFromLink
  105. ############################
  106. class LoadUrlForm(BaseModel):
  107. url: HttpUrl
  108. def github_url_to_raw_url(url: str) -> str:
  109. # Handle 'tree' (folder) URLs (add main.py at the end)
  110. m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
  111. if m1:
  112. org, repo, branch, path = m1.groups()
  113. return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
  114. # Handle 'blob' (file) URLs
  115. m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
  116. if m2:
  117. org, repo, branch, path = m2.groups()
  118. return (
  119. f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
  120. )
  121. # No match; return as-is
  122. return url
  123. @router.post("/load/url", response_model=Optional[dict])
  124. async def load_tool_from_url(
  125. request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
  126. ):
  127. # NOTE: This is NOT a SSRF vulnerability:
  128. # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
  129. # and does NOT accept untrusted user input. Access is enforced by authentication.
  130. url = str(form_data.url)
  131. if not url:
  132. raise HTTPException(status_code=400, detail="Please enter a valid URL")
  133. url = github_url_to_raw_url(url)
  134. url_parts = url.rstrip("/").split("/")
  135. file_name = url_parts[-1]
  136. tool_name = (
  137. file_name[:-3]
  138. if (
  139. file_name.endswith(".py")
  140. and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
  141. )
  142. else url_parts[-2] if len(url_parts) > 1 else "function"
  143. )
  144. try:
  145. async with aiohttp.ClientSession(trust_env=True) as session:
  146. async with session.get(
  147. url, headers={"Content-Type": "application/json"}
  148. ) as resp:
  149. if resp.status != 200:
  150. raise HTTPException(
  151. status_code=resp.status, detail="Failed to fetch the tool"
  152. )
  153. data = await resp.text()
  154. if not data:
  155. raise HTTPException(
  156. status_code=400, detail="No data received from the URL"
  157. )
  158. return {
  159. "name": tool_name,
  160. "content": data,
  161. }
  162. except Exception as e:
  163. raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
  164. ############################
  165. # ExportTools
  166. ############################
  167. @router.get("/export", response_model=list[ToolModel])
  168. async def export_tools(user=Depends(get_admin_user)):
  169. tools = Tools.get_tools()
  170. return tools
  171. ############################
  172. # CreateNewTools
  173. ############################
  174. @router.post("/create", response_model=Optional[ToolResponse])
  175. async def create_new_tools(
  176. request: Request,
  177. form_data: ToolForm,
  178. user=Depends(get_verified_user),
  179. ):
  180. if user.role != "admin" and not has_permission(
  181. user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
  182. ):
  183. raise HTTPException(
  184. status_code=status.HTTP_401_UNAUTHORIZED,
  185. detail=ERROR_MESSAGES.UNAUTHORIZED,
  186. )
  187. if not form_data.id.isidentifier():
  188. raise HTTPException(
  189. status_code=status.HTTP_400_BAD_REQUEST,
  190. detail="Only alphanumeric characters and underscores are allowed in the id",
  191. )
  192. form_data.id = form_data.id.lower()
  193. tools = Tools.get_tool_by_id(form_data.id)
  194. if tools is None:
  195. try:
  196. form_data.content = replace_imports(form_data.content)
  197. tool_module, frontmatter = load_tool_module_by_id(
  198. form_data.id, content=form_data.content
  199. )
  200. form_data.meta.manifest = frontmatter
  201. TOOLS = request.app.state.TOOLS
  202. TOOLS[form_data.id] = tool_module
  203. specs = get_tool_specs(TOOLS[form_data.id])
  204. tools = Tools.insert_new_tool(user.id, form_data, specs)
  205. tool_cache_dir = CACHE_DIR / "tools" / form_data.id
  206. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  207. if tools:
  208. return tools
  209. else:
  210. raise HTTPException(
  211. status_code=status.HTTP_400_BAD_REQUEST,
  212. detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
  213. )
  214. except Exception as e:
  215. log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
  216. raise HTTPException(
  217. status_code=status.HTTP_400_BAD_REQUEST,
  218. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  219. )
  220. else:
  221. raise HTTPException(
  222. status_code=status.HTTP_400_BAD_REQUEST,
  223. detail=ERROR_MESSAGES.ID_TAKEN,
  224. )
  225. ############################
  226. # GetToolsById
  227. ############################
  228. @router.get("/id/{id}", response_model=Optional[ToolModel])
  229. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  230. tools = Tools.get_tool_by_id(id)
  231. if tools:
  232. if (
  233. user.role == "admin"
  234. or tools.user_id == user.id
  235. or has_access(user.id, "read", tools.access_control)
  236. ):
  237. return tools
  238. else:
  239. raise HTTPException(
  240. status_code=status.HTTP_401_UNAUTHORIZED,
  241. detail=ERROR_MESSAGES.NOT_FOUND,
  242. )
  243. ############################
  244. # UpdateToolsById
  245. ############################
  246. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  247. async def update_tools_by_id(
  248. request: Request,
  249. id: str,
  250. form_data: ToolForm,
  251. user=Depends(get_verified_user),
  252. ):
  253. tools = Tools.get_tool_by_id(id)
  254. if not tools:
  255. raise HTTPException(
  256. status_code=status.HTTP_401_UNAUTHORIZED,
  257. detail=ERROR_MESSAGES.NOT_FOUND,
  258. )
  259. # Is the user the original creator, in a group with write access, or an admin
  260. if (
  261. tools.user_id != user.id
  262. and not has_access(user.id, "write", tools.access_control)
  263. and user.role != "admin"
  264. ):
  265. raise HTTPException(
  266. status_code=status.HTTP_401_UNAUTHORIZED,
  267. detail=ERROR_MESSAGES.UNAUTHORIZED,
  268. )
  269. try:
  270. form_data.content = replace_imports(form_data.content)
  271. tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
  272. form_data.meta.manifest = frontmatter
  273. TOOLS = request.app.state.TOOLS
  274. TOOLS[id] = tool_module
  275. specs = get_tool_specs(TOOLS[id])
  276. updated = {
  277. **form_data.model_dump(exclude={"id"}),
  278. "specs": specs,
  279. }
  280. log.debug(updated)
  281. tools = Tools.update_tool_by_id(id, updated)
  282. if tools:
  283. return tools
  284. else:
  285. raise HTTPException(
  286. status_code=status.HTTP_400_BAD_REQUEST,
  287. detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
  288. )
  289. except Exception as e:
  290. raise HTTPException(
  291. status_code=status.HTTP_400_BAD_REQUEST,
  292. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  293. )
  294. ############################
  295. # DeleteToolsById
  296. ############################
  297. @router.delete("/id/{id}/delete", response_model=bool)
  298. async def delete_tools_by_id(
  299. request: Request, id: str, user=Depends(get_verified_user)
  300. ):
  301. tools = Tools.get_tool_by_id(id)
  302. if not tools:
  303. raise HTTPException(
  304. status_code=status.HTTP_401_UNAUTHORIZED,
  305. detail=ERROR_MESSAGES.NOT_FOUND,
  306. )
  307. if (
  308. tools.user_id != user.id
  309. and not has_access(user.id, "write", tools.access_control)
  310. and user.role != "admin"
  311. ):
  312. raise HTTPException(
  313. status_code=status.HTTP_401_UNAUTHORIZED,
  314. detail=ERROR_MESSAGES.UNAUTHORIZED,
  315. )
  316. result = Tools.delete_tool_by_id(id)
  317. if result:
  318. TOOLS = request.app.state.TOOLS
  319. if id in TOOLS:
  320. del TOOLS[id]
  321. return result
  322. ############################
  323. # GetToolValves
  324. ############################
  325. @router.get("/id/{id}/valves", response_model=Optional[dict])
  326. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  327. tools = Tools.get_tool_by_id(id)
  328. if tools:
  329. try:
  330. valves = Tools.get_tool_valves_by_id(id)
  331. return valves
  332. except Exception as e:
  333. raise HTTPException(
  334. status_code=status.HTTP_400_BAD_REQUEST,
  335. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  336. )
  337. else:
  338. raise HTTPException(
  339. status_code=status.HTTP_401_UNAUTHORIZED,
  340. detail=ERROR_MESSAGES.NOT_FOUND,
  341. )
  342. ############################
  343. # GetToolValvesSpec
  344. ############################
  345. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  346. async def get_tools_valves_spec_by_id(
  347. request: Request, id: str, user=Depends(get_verified_user)
  348. ):
  349. tools = Tools.get_tool_by_id(id)
  350. if tools:
  351. if id in request.app.state.TOOLS:
  352. tools_module = request.app.state.TOOLS[id]
  353. else:
  354. tools_module, _ = load_tool_module_by_id(id)
  355. request.app.state.TOOLS[id] = tools_module
  356. if hasattr(tools_module, "Valves"):
  357. Valves = tools_module.Valves
  358. return Valves.schema()
  359. return None
  360. else:
  361. raise HTTPException(
  362. status_code=status.HTTP_401_UNAUTHORIZED,
  363. detail=ERROR_MESSAGES.NOT_FOUND,
  364. )
  365. ############################
  366. # UpdateToolValves
  367. ############################
  368. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  369. async def update_tools_valves_by_id(
  370. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  371. ):
  372. tools = Tools.get_tool_by_id(id)
  373. if not tools:
  374. raise HTTPException(
  375. status_code=status.HTTP_401_UNAUTHORIZED,
  376. detail=ERROR_MESSAGES.NOT_FOUND,
  377. )
  378. if (
  379. tools.user_id != user.id
  380. and not has_access(user.id, "write", tools.access_control)
  381. and user.role != "admin"
  382. ):
  383. raise HTTPException(
  384. status_code=status.HTTP_400_BAD_REQUEST,
  385. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  386. )
  387. if id in request.app.state.TOOLS:
  388. tools_module = request.app.state.TOOLS[id]
  389. else:
  390. tools_module, _ = load_tool_module_by_id(id)
  391. request.app.state.TOOLS[id] = tools_module
  392. if not hasattr(tools_module, "Valves"):
  393. raise HTTPException(
  394. status_code=status.HTTP_401_UNAUTHORIZED,
  395. detail=ERROR_MESSAGES.NOT_FOUND,
  396. )
  397. Valves = tools_module.Valves
  398. try:
  399. form_data = {k: v for k, v in form_data.items() if v is not None}
  400. valves = Valves(**form_data)
  401. Tools.update_tool_valves_by_id(id, valves.model_dump())
  402. return valves.model_dump()
  403. except Exception as e:
  404. log.exception(f"Failed to update tool valves by id {id}: {e}")
  405. raise HTTPException(
  406. status_code=status.HTTP_400_BAD_REQUEST,
  407. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  408. )
  409. ############################
  410. # ToolUserValves
  411. ############################
  412. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  413. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  414. tools = Tools.get_tool_by_id(id)
  415. if tools:
  416. try:
  417. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  418. return user_valves
  419. except Exception as e:
  420. raise HTTPException(
  421. status_code=status.HTTP_400_BAD_REQUEST,
  422. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  423. )
  424. else:
  425. raise HTTPException(
  426. status_code=status.HTTP_401_UNAUTHORIZED,
  427. detail=ERROR_MESSAGES.NOT_FOUND,
  428. )
  429. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  430. async def get_tools_user_valves_spec_by_id(
  431. request: Request, id: str, user=Depends(get_verified_user)
  432. ):
  433. tools = Tools.get_tool_by_id(id)
  434. if tools:
  435. if id in request.app.state.TOOLS:
  436. tools_module = request.app.state.TOOLS[id]
  437. else:
  438. tools_module, _ = load_tool_module_by_id(id)
  439. request.app.state.TOOLS[id] = tools_module
  440. if hasattr(tools_module, "UserValves"):
  441. UserValves = tools_module.UserValves
  442. return UserValves.schema()
  443. return None
  444. else:
  445. raise HTTPException(
  446. status_code=status.HTTP_401_UNAUTHORIZED,
  447. detail=ERROR_MESSAGES.NOT_FOUND,
  448. )
  449. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  450. async def update_tools_user_valves_by_id(
  451. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  452. ):
  453. tools = Tools.get_tool_by_id(id)
  454. if tools:
  455. if id in request.app.state.TOOLS:
  456. tools_module = request.app.state.TOOLS[id]
  457. else:
  458. tools_module, _ = load_tool_module_by_id(id)
  459. request.app.state.TOOLS[id] = tools_module
  460. if hasattr(tools_module, "UserValves"):
  461. UserValves = tools_module.UserValves
  462. try:
  463. form_data = {k: v for k, v in form_data.items() if v is not None}
  464. user_valves = UserValves(**form_data)
  465. Tools.update_user_valves_by_id_and_user_id(
  466. id, user.id, user_valves.model_dump()
  467. )
  468. return user_valves.model_dump()
  469. except Exception as e:
  470. log.exception(f"Failed to update user valves by id {id}: {e}")
  471. raise HTTPException(
  472. status_code=status.HTTP_400_BAD_REQUEST,
  473. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  474. )
  475. else:
  476. raise HTTPException(
  477. status_code=status.HTTP_401_UNAUTHORIZED,
  478. detail=ERROR_MESSAGES.NOT_FOUND,
  479. )
  480. else:
  481. raise HTTPException(
  482. status_code=status.HTTP_401_UNAUTHORIZED,
  483. detail=ERROR_MESSAGES.NOT_FOUND,
  484. )