tools.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. import logging
  2. from pathlib import Path
  3. from typing import Optional
  4. import time
  5. import re
  6. import aiohttp
  7. from open_webui.models.groups import Groups
  8. from pydantic import BaseModel, HttpUrl
  9. from fastapi import APIRouter, Depends, HTTPException, Request, status
  10. from open_webui.models.tools import (
  11. ToolForm,
  12. ToolModel,
  13. ToolResponse,
  14. ToolUserResponse,
  15. Tools,
  16. )
  17. from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
  18. from open_webui.utils.tools import get_tool_specs
  19. from open_webui.utils.auth import get_admin_user, get_verified_user
  20. from open_webui.utils.access_control import has_access, has_permission
  21. from open_webui.utils.tools import get_tool_servers
  22. from open_webui.env import SRC_LOG_LEVELS
  23. from open_webui.config import CACHE_DIR, BYPASS_ADMIN_ACCESS_CONTROL
  24. from open_webui.constants import ERROR_MESSAGES
  25. log = logging.getLogger(__name__)
  26. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  27. router = APIRouter()
  28. ############################
  29. # GetTools
  30. ############################
  31. @router.get("/", response_model=list[ToolUserResponse])
  32. async def get_tools(request: Request, user=Depends(get_verified_user)):
  33. tools = Tools.get_tools()
  34. for server in await get_tool_servers(request):
  35. tools.append(
  36. ToolUserResponse(
  37. **{
  38. "id": f"server:{server.get('id')}",
  39. "user_id": f"server:{server.get('id')}",
  40. "name": server.get("openapi", {})
  41. .get("info", {})
  42. .get("title", "Tool Server"),
  43. "meta": {
  44. "description": server.get("openapi", {})
  45. .get("info", {})
  46. .get("description", ""),
  47. },
  48. "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
  49. server.get("idx", 0)
  50. ]
  51. .get("config", {})
  52. .get("access_control", None),
  53. "updated_at": int(time.time()),
  54. "created_at": int(time.time()),
  55. }
  56. )
  57. )
  58. if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
  59. # Admin can see all tools
  60. return tools
  61. else:
  62. user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
  63. tools = [
  64. tool
  65. for tool in tools
  66. if tool.user_id == user.id
  67. or has_access(user.id, "read", tool.access_control, user_group_ids)
  68. ]
  69. return tools
  70. ############################
  71. # GetToolList
  72. ############################
  73. @router.get("/list", response_model=list[ToolUserResponse])
  74. async def get_tool_list(user=Depends(get_verified_user)):
  75. if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
  76. tools = Tools.get_tools()
  77. else:
  78. tools = Tools.get_tools_by_user_id(user.id, "write")
  79. return tools
  80. ############################
  81. # LoadFunctionFromLink
  82. ############################
  83. class LoadUrlForm(BaseModel):
  84. url: HttpUrl
  85. def github_url_to_raw_url(url: str) -> str:
  86. # Handle 'tree' (folder) URLs (add main.py at the end)
  87. m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
  88. if m1:
  89. org, repo, branch, path = m1.groups()
  90. return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
  91. # Handle 'blob' (file) URLs
  92. m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
  93. if m2:
  94. org, repo, branch, path = m2.groups()
  95. return (
  96. f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
  97. )
  98. # No match; return as-is
  99. return url
  100. @router.post("/load/url", response_model=Optional[dict])
  101. async def load_tool_from_url(
  102. request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
  103. ):
  104. # NOTE: This is NOT a SSRF vulnerability:
  105. # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
  106. # and does NOT accept untrusted user input. Access is enforced by authentication.
  107. url = str(form_data.url)
  108. if not url:
  109. raise HTTPException(status_code=400, detail="Please enter a valid URL")
  110. url = github_url_to_raw_url(url)
  111. url_parts = url.rstrip("/").split("/")
  112. file_name = url_parts[-1]
  113. tool_name = (
  114. file_name[:-3]
  115. if (
  116. file_name.endswith(".py")
  117. and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
  118. )
  119. else url_parts[-2] if len(url_parts) > 1 else "function"
  120. )
  121. try:
  122. async with aiohttp.ClientSession(trust_env=True) as session:
  123. async with session.get(
  124. url, headers={"Content-Type": "application/json"}
  125. ) as resp:
  126. if resp.status != 200:
  127. raise HTTPException(
  128. status_code=resp.status, detail="Failed to fetch the tool"
  129. )
  130. data = await resp.text()
  131. if not data:
  132. raise HTTPException(
  133. status_code=400, detail="No data received from the URL"
  134. )
  135. return {
  136. "name": tool_name,
  137. "content": data,
  138. }
  139. except Exception as e:
  140. raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
  141. ############################
  142. # ExportTools
  143. ############################
  144. @router.get("/export", response_model=list[ToolModel])
  145. async def export_tools(user=Depends(get_admin_user)):
  146. tools = Tools.get_tools()
  147. return tools
  148. ############################
  149. # CreateNewTools
  150. ############################
  151. @router.post("/create", response_model=Optional[ToolResponse])
  152. async def create_new_tools(
  153. request: Request,
  154. form_data: ToolForm,
  155. user=Depends(get_verified_user),
  156. ):
  157. if user.role != "admin" and not has_permission(
  158. user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
  159. ):
  160. raise HTTPException(
  161. status_code=status.HTTP_401_UNAUTHORIZED,
  162. detail=ERROR_MESSAGES.UNAUTHORIZED,
  163. )
  164. if not form_data.id.isidentifier():
  165. raise HTTPException(
  166. status_code=status.HTTP_400_BAD_REQUEST,
  167. detail="Only alphanumeric characters and underscores are allowed in the id",
  168. )
  169. form_data.id = form_data.id.lower()
  170. tools = Tools.get_tool_by_id(form_data.id)
  171. if tools is None:
  172. try:
  173. form_data.content = replace_imports(form_data.content)
  174. tool_module, frontmatter = load_tool_module_by_id(
  175. form_data.id, content=form_data.content
  176. )
  177. form_data.meta.manifest = frontmatter
  178. TOOLS = request.app.state.TOOLS
  179. TOOLS[form_data.id] = tool_module
  180. specs = get_tool_specs(TOOLS[form_data.id])
  181. tools = Tools.insert_new_tool(user.id, form_data, specs)
  182. tool_cache_dir = CACHE_DIR / "tools" / form_data.id
  183. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  184. if tools:
  185. return tools
  186. else:
  187. raise HTTPException(
  188. status_code=status.HTTP_400_BAD_REQUEST,
  189. detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
  190. )
  191. except Exception as e:
  192. log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
  193. raise HTTPException(
  194. status_code=status.HTTP_400_BAD_REQUEST,
  195. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  196. )
  197. else:
  198. raise HTTPException(
  199. status_code=status.HTTP_400_BAD_REQUEST,
  200. detail=ERROR_MESSAGES.ID_TAKEN,
  201. )
  202. ############################
  203. # GetToolsById
  204. ############################
  205. @router.get("/id/{id}", response_model=Optional[ToolModel])
  206. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  207. tools = Tools.get_tool_by_id(id)
  208. if tools:
  209. if (
  210. user.role == "admin"
  211. or tools.user_id == user.id
  212. or has_access(user.id, "read", tools.access_control)
  213. ):
  214. return tools
  215. else:
  216. raise HTTPException(
  217. status_code=status.HTTP_401_UNAUTHORIZED,
  218. detail=ERROR_MESSAGES.NOT_FOUND,
  219. )
  220. ############################
  221. # UpdateToolsById
  222. ############################
  223. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  224. async def update_tools_by_id(
  225. request: Request,
  226. id: str,
  227. form_data: ToolForm,
  228. user=Depends(get_verified_user),
  229. ):
  230. tools = Tools.get_tool_by_id(id)
  231. if not tools:
  232. raise HTTPException(
  233. status_code=status.HTTP_401_UNAUTHORIZED,
  234. detail=ERROR_MESSAGES.NOT_FOUND,
  235. )
  236. # Is the user the original creator, in a group with write access, or an admin
  237. if (
  238. tools.user_id != user.id
  239. and not has_access(user.id, "write", tools.access_control)
  240. and user.role != "admin"
  241. ):
  242. raise HTTPException(
  243. status_code=status.HTTP_401_UNAUTHORIZED,
  244. detail=ERROR_MESSAGES.UNAUTHORIZED,
  245. )
  246. try:
  247. form_data.content = replace_imports(form_data.content)
  248. tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
  249. form_data.meta.manifest = frontmatter
  250. TOOLS = request.app.state.TOOLS
  251. TOOLS[id] = tool_module
  252. specs = get_tool_specs(TOOLS[id])
  253. updated = {
  254. **form_data.model_dump(exclude={"id"}),
  255. "specs": specs,
  256. }
  257. log.debug(updated)
  258. tools = Tools.update_tool_by_id(id, updated)
  259. if tools:
  260. return tools
  261. else:
  262. raise HTTPException(
  263. status_code=status.HTTP_400_BAD_REQUEST,
  264. detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
  265. )
  266. except Exception as e:
  267. raise HTTPException(
  268. status_code=status.HTTP_400_BAD_REQUEST,
  269. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  270. )
  271. ############################
  272. # DeleteToolsById
  273. ############################
  274. @router.delete("/id/{id}/delete", response_model=bool)
  275. async def delete_tools_by_id(
  276. request: Request, id: str, user=Depends(get_verified_user)
  277. ):
  278. tools = Tools.get_tool_by_id(id)
  279. if not tools:
  280. raise HTTPException(
  281. status_code=status.HTTP_401_UNAUTHORIZED,
  282. detail=ERROR_MESSAGES.NOT_FOUND,
  283. )
  284. if (
  285. tools.user_id != user.id
  286. and not has_access(user.id, "write", tools.access_control)
  287. and user.role != "admin"
  288. ):
  289. raise HTTPException(
  290. status_code=status.HTTP_401_UNAUTHORIZED,
  291. detail=ERROR_MESSAGES.UNAUTHORIZED,
  292. )
  293. result = Tools.delete_tool_by_id(id)
  294. if result:
  295. TOOLS = request.app.state.TOOLS
  296. if id in TOOLS:
  297. del TOOLS[id]
  298. return result
  299. ############################
  300. # GetToolValves
  301. ############################
  302. @router.get("/id/{id}/valves", response_model=Optional[dict])
  303. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  304. tools = Tools.get_tool_by_id(id)
  305. if tools:
  306. try:
  307. valves = Tools.get_tool_valves_by_id(id)
  308. return valves
  309. except Exception as e:
  310. raise HTTPException(
  311. status_code=status.HTTP_400_BAD_REQUEST,
  312. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  313. )
  314. else:
  315. raise HTTPException(
  316. status_code=status.HTTP_401_UNAUTHORIZED,
  317. detail=ERROR_MESSAGES.NOT_FOUND,
  318. )
  319. ############################
  320. # GetToolValvesSpec
  321. ############################
  322. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  323. async def get_tools_valves_spec_by_id(
  324. request: Request, id: str, user=Depends(get_verified_user)
  325. ):
  326. tools = Tools.get_tool_by_id(id)
  327. if tools:
  328. if id in request.app.state.TOOLS:
  329. tools_module = request.app.state.TOOLS[id]
  330. else:
  331. tools_module, _ = load_tool_module_by_id(id)
  332. request.app.state.TOOLS[id] = tools_module
  333. if hasattr(tools_module, "Valves"):
  334. Valves = tools_module.Valves
  335. return Valves.schema()
  336. return None
  337. else:
  338. raise HTTPException(
  339. status_code=status.HTTP_401_UNAUTHORIZED,
  340. detail=ERROR_MESSAGES.NOT_FOUND,
  341. )
  342. ############################
  343. # UpdateToolValves
  344. ############################
  345. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  346. async def update_tools_valves_by_id(
  347. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  348. ):
  349. tools = Tools.get_tool_by_id(id)
  350. if not tools:
  351. raise HTTPException(
  352. status_code=status.HTTP_401_UNAUTHORIZED,
  353. detail=ERROR_MESSAGES.NOT_FOUND,
  354. )
  355. if (
  356. tools.user_id != user.id
  357. and not has_access(user.id, "write", tools.access_control)
  358. and user.role != "admin"
  359. ):
  360. raise HTTPException(
  361. status_code=status.HTTP_400_BAD_REQUEST,
  362. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  363. )
  364. if id in request.app.state.TOOLS:
  365. tools_module = request.app.state.TOOLS[id]
  366. else:
  367. tools_module, _ = load_tool_module_by_id(id)
  368. request.app.state.TOOLS[id] = tools_module
  369. if not hasattr(tools_module, "Valves"):
  370. raise HTTPException(
  371. status_code=status.HTTP_401_UNAUTHORIZED,
  372. detail=ERROR_MESSAGES.NOT_FOUND,
  373. )
  374. Valves = tools_module.Valves
  375. try:
  376. form_data = {k: v for k, v in form_data.items() if v is not None}
  377. valves = Valves(**form_data)
  378. Tools.update_tool_valves_by_id(id, valves.model_dump())
  379. return valves.model_dump()
  380. except Exception as e:
  381. log.exception(f"Failed to update tool valves by id {id}: {e}")
  382. raise HTTPException(
  383. status_code=status.HTTP_400_BAD_REQUEST,
  384. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  385. )
  386. ############################
  387. # ToolUserValves
  388. ############################
  389. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  390. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  391. tools = Tools.get_tool_by_id(id)
  392. if tools:
  393. try:
  394. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  395. return user_valves
  396. except Exception as e:
  397. raise HTTPException(
  398. status_code=status.HTTP_400_BAD_REQUEST,
  399. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  400. )
  401. else:
  402. raise HTTPException(
  403. status_code=status.HTTP_401_UNAUTHORIZED,
  404. detail=ERROR_MESSAGES.NOT_FOUND,
  405. )
  406. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  407. async def get_tools_user_valves_spec_by_id(
  408. request: Request, id: str, user=Depends(get_verified_user)
  409. ):
  410. tools = Tools.get_tool_by_id(id)
  411. if tools:
  412. if id in request.app.state.TOOLS:
  413. tools_module = request.app.state.TOOLS[id]
  414. else:
  415. tools_module, _ = load_tool_module_by_id(id)
  416. request.app.state.TOOLS[id] = tools_module
  417. if hasattr(tools_module, "UserValves"):
  418. UserValves = tools_module.UserValves
  419. return UserValves.schema()
  420. return None
  421. else:
  422. raise HTTPException(
  423. status_code=status.HTTP_401_UNAUTHORIZED,
  424. detail=ERROR_MESSAGES.NOT_FOUND,
  425. )
  426. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  427. async def update_tools_user_valves_by_id(
  428. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  429. ):
  430. tools = Tools.get_tool_by_id(id)
  431. if tools:
  432. if id in request.app.state.TOOLS:
  433. tools_module = request.app.state.TOOLS[id]
  434. else:
  435. tools_module, _ = load_tool_module_by_id(id)
  436. request.app.state.TOOLS[id] = tools_module
  437. if hasattr(tools_module, "UserValves"):
  438. UserValves = tools_module.UserValves
  439. try:
  440. form_data = {k: v for k, v in form_data.items() if v is not None}
  441. user_valves = UserValves(**form_data)
  442. Tools.update_user_valves_by_id_and_user_id(
  443. id, user.id, user_valves.model_dump()
  444. )
  445. return user_valves.model_dump()
  446. except Exception as e:
  447. log.exception(f"Failed to update user valves by id {id}: {e}")
  448. raise HTTPException(
  449. status_code=status.HTTP_400_BAD_REQUEST,
  450. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  451. )
  452. else:
  453. raise HTTPException(
  454. status_code=status.HTTP_401_UNAUTHORIZED,
  455. detail=ERROR_MESSAGES.NOT_FOUND,
  456. )
  457. else:
  458. raise HTTPException(
  459. status_code=status.HTTP_401_UNAUTHORIZED,
  460. detail=ERROR_MESSAGES.NOT_FOUND,
  461. )