tools.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594
  1. import logging
  2. from pathlib import Path
  3. from typing import Optional
  4. import time
  5. import re
  6. import aiohttp
  7. from open_webui.models.groups import Groups
  8. from pydantic import BaseModel, HttpUrl
  9. from fastapi import APIRouter, Depends, HTTPException, Request, status
  10. from open_webui.models.tools import (
  11. ToolForm,
  12. ToolModel,
  13. ToolResponse,
  14. ToolUserResponse,
  15. Tools,
  16. )
  17. from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
  18. from open_webui.utils.tools import get_tool_specs
  19. from open_webui.utils.auth import get_admin_user, get_verified_user
  20. from open_webui.utils.access_control import has_access, has_permission
  21. from open_webui.utils.tools import get_tool_servers
  22. from open_webui.env import SRC_LOG_LEVELS
  23. from open_webui.config import CACHE_DIR, BYPASS_ADMIN_ACCESS_CONTROL
  24. from open_webui.constants import ERROR_MESSAGES
# Module-level logger; its level is taken from the "MAIN" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])

# Router on which every endpoint defined below is registered.
router = APIRouter()
  28. ############################
  29. # GetTools
  30. ############################
  31. @router.get("/", response_model=list[ToolUserResponse])
  32. async def get_tools(request: Request, user=Depends(get_verified_user)):
  33. tools = [
  34. ToolUserResponse(
  35. **{
  36. **tool.model_dump(),
  37. "has_user_valves": "class UserValves(BaseModel):" in tool.content,
  38. }
  39. )
  40. for tool in Tools.get_tools()
  41. ]
  42. # OpenAPI Tool Servers
  43. for server in await get_tool_servers(request):
  44. tools.append(
  45. ToolUserResponse(
  46. **{
  47. "id": f"server:{server.get('id')}",
  48. "user_id": f"server:{server.get('id')}",
  49. "name": server.get("openapi", {})
  50. .get("info", {})
  51. .get("title", "Tool Server"),
  52. "meta": {
  53. "description": server.get("openapi", {})
  54. .get("info", {})
  55. .get("description", ""),
  56. },
  57. "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
  58. server.get("idx", 0)
  59. ]
  60. .get("config", {})
  61. .get("access_control", None),
  62. "updated_at": int(time.time()),
  63. "created_at": int(time.time()),
  64. }
  65. )
  66. )
  67. # MCP Tool Servers
  68. for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
  69. if server.get("type", "openapi") == "mcp":
  70. tools.append(
  71. ToolUserResponse(
  72. **{
  73. "id": f"server:mcp:{server.get('info', {}).get('id')}",
  74. "user_id": f"server:mcp:{server.get('info', {}).get('id')}",
  75. "name": server.get("info", {}).get("name", "MCP Tool Server"),
  76. "meta": {
  77. "description": server.get("info", {}).get(
  78. "description", ""
  79. ),
  80. },
  81. "access_control": server.get("config", {}).get(
  82. "access_control", None
  83. ),
  84. "updated_at": int(time.time()),
  85. "created_at": int(time.time()),
  86. }
  87. )
  88. )
  89. if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
  90. # Admin can see all tools
  91. return tools
  92. else:
  93. user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
  94. tools = [
  95. tool
  96. for tool in tools
  97. if tool.user_id == user.id
  98. or has_access(user.id, "read", tool.access_control, user_group_ids)
  99. ]
  100. return tools
  101. ############################
  102. # GetToolList
  103. ############################
  104. @router.get("/list", response_model=list[ToolUserResponse])
  105. async def get_tool_list(user=Depends(get_verified_user)):
  106. if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
  107. tools = Tools.get_tools()
  108. else:
  109. tools = Tools.get_tools_by_user_id(user.id, "write")
  110. return tools
  111. ############################
  112. # LoadFunctionFromLink
  113. ############################
# Request body for the "/load/url" endpoint: the location to fetch tool
# source code from.
class LoadUrlForm(BaseModel):
    # Declared as pydantic's HttpUrl type; the handler converts it with str()
    # before use.
    url: HttpUrl
  116. def github_url_to_raw_url(url: str) -> str:
  117. # Handle 'tree' (folder) URLs (add main.py at the end)
  118. m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
  119. if m1:
  120. org, repo, branch, path = m1.groups()
  121. return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
  122. # Handle 'blob' (file) URLs
  123. m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
  124. if m2:
  125. org, repo, branch, path = m2.groups()
  126. return (
  127. f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
  128. )
  129. # No match; return as-is
  130. return url
  131. @router.post("/load/url", response_model=Optional[dict])
  132. async def load_tool_from_url(
  133. request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
  134. ):
  135. # NOTE: This is NOT a SSRF vulnerability:
  136. # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
  137. # and does NOT accept untrusted user input. Access is enforced by authentication.
  138. url = str(form_data.url)
  139. if not url:
  140. raise HTTPException(status_code=400, detail="Please enter a valid URL")
  141. url = github_url_to_raw_url(url)
  142. url_parts = url.rstrip("/").split("/")
  143. file_name = url_parts[-1]
  144. tool_name = (
  145. file_name[:-3]
  146. if (
  147. file_name.endswith(".py")
  148. and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
  149. )
  150. else url_parts[-2] if len(url_parts) > 1 else "function"
  151. )
  152. try:
  153. async with aiohttp.ClientSession(trust_env=True) as session:
  154. async with session.get(
  155. url, headers={"Content-Type": "application/json"}
  156. ) as resp:
  157. if resp.status != 200:
  158. raise HTTPException(
  159. status_code=resp.status, detail="Failed to fetch the tool"
  160. )
  161. data = await resp.text()
  162. if not data:
  163. raise HTTPException(
  164. status_code=400, detail="No data received from the URL"
  165. )
  166. return {
  167. "name": tool_name,
  168. "content": data,
  169. }
  170. except Exception as e:
  171. raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
  172. ############################
  173. # ExportTools
  174. ############################
  175. @router.get("/export", response_model=list[ToolModel])
  176. async def export_tools(user=Depends(get_admin_user)):
  177. tools = Tools.get_tools()
  178. return tools
  179. ############################
  180. # CreateNewTools
  181. ############################
  182. @router.post("/create", response_model=Optional[ToolResponse])
  183. async def create_new_tools(
  184. request: Request,
  185. form_data: ToolForm,
  186. user=Depends(get_verified_user),
  187. ):
  188. if user.role != "admin" and not has_permission(
  189. user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
  190. ):
  191. raise HTTPException(
  192. status_code=status.HTTP_401_UNAUTHORIZED,
  193. detail=ERROR_MESSAGES.UNAUTHORIZED,
  194. )
  195. if not form_data.id.isidentifier():
  196. raise HTTPException(
  197. status_code=status.HTTP_400_BAD_REQUEST,
  198. detail="Only alphanumeric characters and underscores are allowed in the id",
  199. )
  200. form_data.id = form_data.id.lower()
  201. tools = Tools.get_tool_by_id(form_data.id)
  202. if tools is None:
  203. try:
  204. form_data.content = replace_imports(form_data.content)
  205. tool_module, frontmatter = load_tool_module_by_id(
  206. form_data.id, content=form_data.content
  207. )
  208. form_data.meta.manifest = frontmatter
  209. TOOLS = request.app.state.TOOLS
  210. TOOLS[form_data.id] = tool_module
  211. specs = get_tool_specs(TOOLS[form_data.id])
  212. tools = Tools.insert_new_tool(user.id, form_data, specs)
  213. tool_cache_dir = CACHE_DIR / "tools" / form_data.id
  214. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  215. if tools:
  216. return tools
  217. else:
  218. raise HTTPException(
  219. status_code=status.HTTP_400_BAD_REQUEST,
  220. detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
  221. )
  222. except Exception as e:
  223. log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
  224. raise HTTPException(
  225. status_code=status.HTTP_400_BAD_REQUEST,
  226. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  227. )
  228. else:
  229. raise HTTPException(
  230. status_code=status.HTTP_400_BAD_REQUEST,
  231. detail=ERROR_MESSAGES.ID_TAKEN,
  232. )
  233. ############################
  234. # GetToolsById
  235. ############################
  236. @router.get("/id/{id}", response_model=Optional[ToolModel])
  237. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  238. tools = Tools.get_tool_by_id(id)
  239. if tools:
  240. if (
  241. user.role == "admin"
  242. or tools.user_id == user.id
  243. or has_access(user.id, "read", tools.access_control)
  244. ):
  245. return tools
  246. else:
  247. raise HTTPException(
  248. status_code=status.HTTP_401_UNAUTHORIZED,
  249. detail=ERROR_MESSAGES.NOT_FOUND,
  250. )
  251. ############################
  252. # UpdateToolsById
  253. ############################
  254. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  255. async def update_tools_by_id(
  256. request: Request,
  257. id: str,
  258. form_data: ToolForm,
  259. user=Depends(get_verified_user),
  260. ):
  261. tools = Tools.get_tool_by_id(id)
  262. if not tools:
  263. raise HTTPException(
  264. status_code=status.HTTP_401_UNAUTHORIZED,
  265. detail=ERROR_MESSAGES.NOT_FOUND,
  266. )
  267. # Is the user the original creator, in a group with write access, or an admin
  268. if (
  269. tools.user_id != user.id
  270. and not has_access(user.id, "write", tools.access_control)
  271. and user.role != "admin"
  272. ):
  273. raise HTTPException(
  274. status_code=status.HTTP_401_UNAUTHORIZED,
  275. detail=ERROR_MESSAGES.UNAUTHORIZED,
  276. )
  277. try:
  278. form_data.content = replace_imports(form_data.content)
  279. tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
  280. form_data.meta.manifest = frontmatter
  281. TOOLS = request.app.state.TOOLS
  282. TOOLS[id] = tool_module
  283. specs = get_tool_specs(TOOLS[id])
  284. updated = {
  285. **form_data.model_dump(exclude={"id"}),
  286. "specs": specs,
  287. }
  288. log.debug(updated)
  289. tools = Tools.update_tool_by_id(id, updated)
  290. if tools:
  291. return tools
  292. else:
  293. raise HTTPException(
  294. status_code=status.HTTP_400_BAD_REQUEST,
  295. detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
  296. )
  297. except Exception as e:
  298. raise HTTPException(
  299. status_code=status.HTTP_400_BAD_REQUEST,
  300. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  301. )
  302. ############################
  303. # DeleteToolsById
  304. ############################
  305. @router.delete("/id/{id}/delete", response_model=bool)
  306. async def delete_tools_by_id(
  307. request: Request, id: str, user=Depends(get_verified_user)
  308. ):
  309. tools = Tools.get_tool_by_id(id)
  310. if not tools:
  311. raise HTTPException(
  312. status_code=status.HTTP_401_UNAUTHORIZED,
  313. detail=ERROR_MESSAGES.NOT_FOUND,
  314. )
  315. if (
  316. tools.user_id != user.id
  317. and not has_access(user.id, "write", tools.access_control)
  318. and user.role != "admin"
  319. ):
  320. raise HTTPException(
  321. status_code=status.HTTP_401_UNAUTHORIZED,
  322. detail=ERROR_MESSAGES.UNAUTHORIZED,
  323. )
  324. result = Tools.delete_tool_by_id(id)
  325. if result:
  326. TOOLS = request.app.state.TOOLS
  327. if id in TOOLS:
  328. del TOOLS[id]
  329. return result
  330. ############################
  331. # GetToolValves
  332. ############################
  333. @router.get("/id/{id}/valves", response_model=Optional[dict])
  334. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  335. tools = Tools.get_tool_by_id(id)
  336. if tools:
  337. try:
  338. valves = Tools.get_tool_valves_by_id(id)
  339. return valves
  340. except Exception as e:
  341. raise HTTPException(
  342. status_code=status.HTTP_400_BAD_REQUEST,
  343. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  344. )
  345. else:
  346. raise HTTPException(
  347. status_code=status.HTTP_401_UNAUTHORIZED,
  348. detail=ERROR_MESSAGES.NOT_FOUND,
  349. )
  350. ############################
  351. # GetToolValvesSpec
  352. ############################
  353. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  354. async def get_tools_valves_spec_by_id(
  355. request: Request, id: str, user=Depends(get_verified_user)
  356. ):
  357. tools = Tools.get_tool_by_id(id)
  358. if tools:
  359. if id in request.app.state.TOOLS:
  360. tools_module = request.app.state.TOOLS[id]
  361. else:
  362. tools_module, _ = load_tool_module_by_id(id)
  363. request.app.state.TOOLS[id] = tools_module
  364. if hasattr(tools_module, "Valves"):
  365. Valves = tools_module.Valves
  366. return Valves.schema()
  367. return None
  368. else:
  369. raise HTTPException(
  370. status_code=status.HTTP_401_UNAUTHORIZED,
  371. detail=ERROR_MESSAGES.NOT_FOUND,
  372. )
  373. ############################
  374. # UpdateToolValves
  375. ############################
  376. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  377. async def update_tools_valves_by_id(
  378. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  379. ):
  380. tools = Tools.get_tool_by_id(id)
  381. if not tools:
  382. raise HTTPException(
  383. status_code=status.HTTP_401_UNAUTHORIZED,
  384. detail=ERROR_MESSAGES.NOT_FOUND,
  385. )
  386. if (
  387. tools.user_id != user.id
  388. and not has_access(user.id, "write", tools.access_control)
  389. and user.role != "admin"
  390. ):
  391. raise HTTPException(
  392. status_code=status.HTTP_400_BAD_REQUEST,
  393. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  394. )
  395. if id in request.app.state.TOOLS:
  396. tools_module = request.app.state.TOOLS[id]
  397. else:
  398. tools_module, _ = load_tool_module_by_id(id)
  399. request.app.state.TOOLS[id] = tools_module
  400. if not hasattr(tools_module, "Valves"):
  401. raise HTTPException(
  402. status_code=status.HTTP_401_UNAUTHORIZED,
  403. detail=ERROR_MESSAGES.NOT_FOUND,
  404. )
  405. Valves = tools_module.Valves
  406. try:
  407. form_data = {k: v for k, v in form_data.items() if v is not None}
  408. valves = Valves(**form_data)
  409. valves_dict = valves.model_dump(exclude_unset=True)
  410. Tools.update_tool_valves_by_id(id, valves_dict)
  411. return valves_dict
  412. except Exception as e:
  413. log.exception(f"Failed to update tool valves by id {id}: {e}")
  414. raise HTTPException(
  415. status_code=status.HTTP_400_BAD_REQUEST,
  416. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  417. )
  418. ############################
  419. # ToolUserValves
  420. ############################
  421. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  422. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  423. tools = Tools.get_tool_by_id(id)
  424. if tools:
  425. try:
  426. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  427. return user_valves
  428. except Exception as e:
  429. raise HTTPException(
  430. status_code=status.HTTP_400_BAD_REQUEST,
  431. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  432. )
  433. else:
  434. raise HTTPException(
  435. status_code=status.HTTP_401_UNAUTHORIZED,
  436. detail=ERROR_MESSAGES.NOT_FOUND,
  437. )
  438. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  439. async def get_tools_user_valves_spec_by_id(
  440. request: Request, id: str, user=Depends(get_verified_user)
  441. ):
  442. tools = Tools.get_tool_by_id(id)
  443. if tools:
  444. if id in request.app.state.TOOLS:
  445. tools_module = request.app.state.TOOLS[id]
  446. else:
  447. tools_module, _ = load_tool_module_by_id(id)
  448. request.app.state.TOOLS[id] = tools_module
  449. if hasattr(tools_module, "UserValves"):
  450. UserValves = tools_module.UserValves
  451. return UserValves.schema()
  452. return None
  453. else:
  454. raise HTTPException(
  455. status_code=status.HTTP_401_UNAUTHORIZED,
  456. detail=ERROR_MESSAGES.NOT_FOUND,
  457. )
  458. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  459. async def update_tools_user_valves_by_id(
  460. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  461. ):
  462. tools = Tools.get_tool_by_id(id)
  463. if tools:
  464. if id in request.app.state.TOOLS:
  465. tools_module = request.app.state.TOOLS[id]
  466. else:
  467. tools_module, _ = load_tool_module_by_id(id)
  468. request.app.state.TOOLS[id] = tools_module
  469. if hasattr(tools_module, "UserValves"):
  470. UserValves = tools_module.UserValves
  471. try:
  472. form_data = {k: v for k, v in form_data.items() if v is not None}
  473. user_valves = UserValves(**form_data)
  474. user_valves_dict = user_valves.model_dump(exclude_unset=True)
  475. Tools.update_user_valves_by_id_and_user_id(
  476. id, user.id, user_valves_dict
  477. )
  478. return user_valves_dict
  479. except Exception as e:
  480. log.exception(f"Failed to update user valves by id {id}: {e}")
  481. raise HTTPException(
  482. status_code=status.HTTP_400_BAD_REQUEST,
  483. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  484. )
  485. else:
  486. raise HTTPException(
  487. status_code=status.HTTP_401_UNAUTHORIZED,
  488. detail=ERROR_MESSAGES.NOT_FOUND,
  489. )
  490. else:
  491. raise HTTPException(
  492. status_code=status.HTTP_401_UNAUTHORIZED,
  493. detail=ERROR_MESSAGES.NOT_FOUND,
  494. )