tools.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. import logging
  2. from pathlib import Path
  3. from typing import Optional
  4. import time
  5. import re
  6. import aiohttp
  7. from pydantic import BaseModel, HttpUrl
  8. from fastapi import APIRouter, Depends, HTTPException, Request, status
  9. from open_webui.models.tools import (
  10. ToolForm,
  11. ToolModel,
  12. ToolResponse,
  13. ToolUserResponse,
  14. Tools,
  15. )
  16. from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
  17. from open_webui.utils.tools import get_tool_specs
  18. from open_webui.utils.auth import get_admin_user, get_verified_user
  19. from open_webui.utils.access_control import has_access, has_permission
  20. from open_webui.utils.tools import get_tool_servers_data
  21. from open_webui.env import SRC_LOG_LEVELS
  22. from open_webui.config import CACHE_DIR, ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
  23. from open_webui.constants import ERROR_MESSAGES
  router = APIRouter()

  # Module-level logger; its verbosity follows the "MAIN" entry of SRC_LOG_LEVELS.
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["MAIN"])
  ############################
  # GetTools
  ############################


  @router.get("/", response_model=list[ToolUserResponse])
  async def get_tools(request: Request, user=Depends(get_verified_user)):
      """List the tools visible to ``user``.

      The result combines the tools stored in the database with one synthetic
      entry per configured OpenAPI tool server (id/user_id ``server:<idx>``).
      Admins see everything when ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS is set;
      everyone else only sees tools they own or have "read" access to.
      """
      app_state = request.app.state

      # Tool-server metadata is fetched lazily and cached on app.state. The
      # fetch is repeated on any request that finds the cache empty (for
      # example when no servers are configured), not only once at startup.
      if not app_state.TOOL_SERVERS:
          app_state.TOOL_SERVERS = await get_tool_servers_data(
              app_state.config.TOOL_SERVER_CONNECTIONS
          )

      tools = Tools.get_tools()
      for server in app_state.TOOL_SERVERS:
          server_id = f"server:{server['idx']}"
          info = server.get("openapi", {}).get("info", {})
          connection_config = app_state.config.TOOL_SERVER_CONNECTIONS[
              server["idx"]
          ].get("config", {})
          tools.append(
              ToolUserResponse(
                  id=server_id,
                  user_id=server_id,
                  name=info.get("title", "Tool Server"),
                  meta={"description": info.get("description", "")},
                  access_control=connection_config.get("access_control", None),
                  updated_at=int(time.time()),
                  created_at=int(time.time()),
              )
          )

      if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
          # Admin can see all tools
          return tools

      return [
          tool
          for tool in tools
          if tool.user_id == user.id
          or has_access(user.id, "read", tool.access_control)
      ]
  ############################
  # GetToolList
  ############################


  @router.get("/list", response_model=list[ToolUserResponse])
  async def get_tool_list(user=Depends(get_verified_user)):
      """Return the tools ``user`` may edit.

      Admins get every tool when ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS is on.
      Anyone else gets what Tools.get_tools_by_user_id returns for the
      "write" permission.
      """
      sees_everything = user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
      if not sees_everything:
          return Tools.get_tools_by_user_id(user.id, "write")
      return Tools.get_tools()
  ############################
  # LoadToolFromLink
  ############################


  class LoadUrlForm(BaseModel):
      # Request body for POST /load/url. Declaring the field as HttpUrl makes
      # pydantic validate it as an http(s) URL before the handler runs.
      url: HttpUrl
  # Matches github.com "tree" (folder) and "blob" (file) page URLs.
  # Groups: owner, repo, "tree"/"blob", ref, path. The ref is a single path
  # segment, so branch names containing "/" are not supported. The path stops
  # before any query string or fragment (e.g. "?plain=1" or "#L10").
  _GITHUB_PAGE_URL_RE = re.compile(
      r"https?://(?:www\.)?github\.com/([^/]+)/([^/]+)/(tree|blob)/([^/?#]+)/?([^?#]*)"
  )


  def github_url_to_raw_url(url: str) -> str:
      """Translate a GitHub web URL into the matching raw.githubusercontent.com URL.

      * ``.../tree/<ref>/<dir>`` (a folder) maps to ``<dir>/main.py`` at that ref.
        A bare ``.../tree/<ref>`` maps to ``main.py`` at the repository root.
      * ``.../blob/<ref>/<file>`` maps to the raw contents of ``<file>``.

      ``<ref>`` may be a branch, a tag or a commit SHA. A URL that does not
      match is returned unchanged. This includes a ``blob`` URL that names no
      file.

      Args:
          url: The URL entered by the user.

      Returns:
          The raw-content URL, or ``url`` itself when no conversion applies.
      """
      match = _GITHUB_PAGE_URL_RE.match(url)
      if match is None:
          # Not a GitHub tree/blob page; return as-is.
          return url

      org, repo, kind, ref, path = match.groups()
      path = path.strip("/")

      if kind == "tree":
          # A folder link points at a tool package whose entry point is main.py.
          path = f"{path}/main.py" if path else "main.py"
      elif not path:
          # A blob URL must name a file; there is nothing sensible to convert.
          return url

      # Use the plain "<owner>/<repo>/<ref>/<path>" form. It resolves branches,
      # tags and commit SHAs alike, whereas "refs/heads/<ref>" only resolves
      # branches.
      return f"https://raw.githubusercontent.com/{org}/{repo}/{ref}/{path}"
  @router.post("/load/url", response_model=Optional[dict])
  async def load_tool_from_url(
      request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
  ):
      """Download tool source code from a URL (admin only).

      GitHub ``tree``/``blob`` page URLs are first rewritten to raw-content URLs
      by github_url_to_raw_url. The tool is not installed here. The caller gets
      back ``{"name": <suggested tool name>, "content": <downloaded text>}``.

      Raises:
          HTTPException: 400 for an empty URL or an empty response body; the
              upstream status for a non-200 response; 500 for any other error
              while fetching.
      """
      # NOTE: This is NOT a SSRF vulnerability:
      # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
      # and does NOT accept untrusted user input. Access is enforced by authentication.
      from urllib.parse import urlsplit

      url = str(form_data.url)
      if not url:
          raise HTTPException(status_code=400, detail="Please enter a valid URL")

      url = github_url_to_raw_url(url)

      # Derive the suggested name from the URL *path* only, so that a query
      # string or fragment ("?token=...", "#L10") never ends up in it.
      segments = [part for part in urlsplit(url).path.split("/") if part]
      file_name = segments[-1] if segments else ""
      if file_name.endswith(".py") and file_name not in (
          "main.py",
          "index.py",
          "__init__.py",
      ):
          tool_name = file_name[:-3]
      elif len(segments) > 1:
          # Entry-point or non-.py file: name the tool after its parent folder.
          tool_name = segments[-2]
      else:
          tool_name = "function"

      try:
          async with aiohttp.ClientSession(trust_env=True) as session:
              async with session.get(
                  url, headers={"Content-Type": "application/json"}
              ) as resp:
                  if resp.status != 200:
                      raise HTTPException(
                          status_code=resp.status, detail="Failed to fetch the tool"
                      )
                  data = await resp.text()
      except HTTPException:
          # Already carries the intended status/detail; don't re-wrap it as a 500.
          raise
      except Exception as e:
          raise HTTPException(
              status_code=500, detail=f"Error importing tool: {e}"
          ) from e

      if not data:
          raise HTTPException(status_code=400, detail="No data received from the URL")

      return {
          "name": tool_name,
          "content": data,
      }
  ############################
  # ExportTools
  ############################


  @router.get("/export", response_model=list[ToolModel])
  async def export_tools(user=Depends(get_admin_user)):
      """Return every stored tool; the get_admin_user dependency limits this to admins."""
      return Tools.get_tools()
  ############################
  # CreateNewTools
  ############################


  @router.post("/create", response_model=Optional[ToolResponse])
  async def create_new_tools(
      request: Request,
      form_data: ToolForm,
      user=Depends(get_verified_user),
  ):
      """Create a new tool from Python source.

      Requires the admin role or the "workspace.tools" user permission. The id
      must be a Python identifier and is stored lower-cased. The source is
      loaded as a module by load_tool_module_by_id and registered in
      ``request.app.state.TOOLS``. It is then persisted with its specs, and a
      per-tool cache directory is created under CACHE_DIR/tools.

      Raises:
          HTTPException: 401 without permission; 400 for an invalid or taken id,
              or when loading/persisting the tool fails.
      """
      if user.role != "admin" and not has_permission(
          user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
      ):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      if not form_data.id.isidentifier():
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Only alphanumeric characters and underscores are allowed in the id",
          )

      form_data.id = form_data.id.lower()
      if Tools.get_tool_by_id(form_data.id) is not None:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.ID_TAKEN,
          )

      TOOLS = request.app.state.TOOLS
      tools = None
      try:
          form_data.content = replace_imports(form_data.content)
          tool_module, frontmatter = load_tool_module_by_id(
              form_data.id, content=form_data.content
          )
          form_data.meta.manifest = frontmatter

          TOOLS[form_data.id] = tool_module
          specs = get_tool_specs(tool_module)
          tools = Tools.insert_new_tool(user.id, form_data, specs)

          if tools:
              tool_cache_dir = CACHE_DIR / "tools" / form_data.id
              tool_cache_dir.mkdir(parents=True, exist_ok=True)
      except Exception as e:
          if not tools:
              # Nothing was persisted: don't leave an orphaned module registered.
              TOOLS.pop(form_data.id, None)
          log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          ) from e

      if not tools:
          # Insert reported failure: undo the in-memory registration as well.
          TOOLS.pop(form_data.id, None)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
          )
      return tools
  ############################
  # GetToolsById
  ############################


  @router.get("/id/{id}", response_model=Optional[ToolModel])
  async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
      """Return a single tool if ``user`` may read it.

      Admins, the tool's owner and users granted "read" access may see it.
      A missing tool and a tool the user may not read both raise 401 with a
      NOT_FOUND detail. The unauthorized case no longer falls through and
      returns ``None``.
      """
      tools = Tools.get_tool_by_id(id)
      if tools and (
          user.role == "admin"
          or tools.user_id == user.id
          or has_access(user.id, "read", tools.access_control)
      ):
          return tools

      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail=ERROR_MESSAGES.NOT_FOUND,
      )
  ############################
  # UpdateToolsById
  ############################


  @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  async def update_tools_by_id(
      request: Request,
      id: str,
      form_data: ToolForm,
      user=Depends(get_verified_user),
  ):
      """Replace a tool's source code, manifest and specs.

      Allowed for the tool's owner, users with "write" access, and admins. The
      new source is loaded as a module and swapped into
      ``request.app.state.TOOLS``, then the database row is updated. If any
      step fails, the previously cached module (or its absence) is restored,
      so the in-memory registry keeps matching the stored tool.

      Raises:
          HTTPException: 401 if the tool is missing or the user lacks access;
              400 if loading or persisting the new source fails.
      """
      tools = Tools.get_tool_by_id(id)
      if not tools:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      # Is the user the original creator, in a group with write access, or an admin
      if (
          tools.user_id != user.id
          and not has_access(user.id, "write", tools.access_control)
          and user.role != "admin"
      ):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      TOOLS = request.app.state.TOOLS
      had_previous = id in TOOLS
      previous_module = TOOLS.get(id)

      def _restore_previous_module():
          # Undo the in-memory swap so TOOLS keeps matching the stored tool.
          if had_previous:
              TOOLS[id] = previous_module
          else:
              TOOLS.pop(id, None)

      try:
          form_data.content = replace_imports(form_data.content)
          tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
          form_data.meta.manifest = frontmatter

          TOOLS[id] = tool_module
          specs = get_tool_specs(tool_module)

          updated = {
              **form_data.model_dump(exclude={"id"}),
              "specs": specs,
          }
          log.debug(updated)
          tools = Tools.update_tool_by_id(id, updated)
      except Exception as e:
          _restore_previous_module()
          log.exception(f"Failed to update the tool by id {id}: {e}")
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          ) from e

      if not tools:
          _restore_previous_module()
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
          )
      return tools
  ############################
  # DeleteToolsById
  ############################


  @router.delete("/id/{id}/delete", response_model=bool)
  async def delete_tools_by_id(
      request: Request, id: str, user=Depends(get_verified_user)
  ):
      """Delete a tool and drop its loaded module from ``request.app.state.TOOLS``.

      Allowed for the tool's owner, users with "write" access, and admins.
      Returns whatever Tools.delete_tool_by_id reports.
      """
      tool = Tools.get_tool_by_id(id)
      if not tool:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      may_write = (
          tool.user_id == user.id
          or has_access(user.id, "write", tool.access_control)
          or user.role == "admin"
      )
      if not may_write:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      deleted = Tools.delete_tool_by_id(id)
      if deleted:
          # Unload the cached module only once the stored tool is gone.
          request.app.state.TOOLS.pop(id, None)
      return deleted
  ############################
  # GetToolValves
  ############################


  @router.get("/id/{id}/valves", response_model=Optional[dict])
  async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
      """Return a tool's stored admin-level valve values.

      Reading valves requires the same rights as updating them in
      update_tools_valves_by_id: ownership, "write" access, or the admin role.

      Raises:
          HTTPException: 401 if the tool is missing or the user lacks access;
              400 if the valves cannot be read.
      """
      # NOTE(review): valves are owner/admin configuration and may hold secrets
      # such as API keys (not visible from this file), so they are not exposed
      # to read-only users.
      tools = Tools.get_tool_by_id(id)
      if not tools:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      if (
          tools.user_id != user.id
          and not has_access(user.id, "write", tools.access_control)
          and user.role != "admin"
      ):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      try:
          return Tools.get_tool_valves_by_id(id)
      except Exception as e:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          ) from e
  ############################
  # GetToolValvesSpec
  ############################


  @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  async def get_tools_valves_spec_by_id(
      request: Request, id: str, user=Depends(get_verified_user)
  ):
      """Return the JSON schema of the tool's ``Valves`` model.

      Returns ``None`` when the tool module defines no ``Valves``. The module is
      loaded with load_tool_module_by_id and cached in
      ``request.app.state.TOOLS`` if it is not there yet.

      Raises:
          HTTPException: 401 if the tool does not exist.
      """
      tools = Tools.get_tool_by_id(id)
      if not tools:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      if id in request.app.state.TOOLS:
          tools_module = request.app.state.TOOLS[id]
      else:
          tools_module, _ = load_tool_module_by_id(id)
          request.app.state.TOOLS[id] = tools_module

      if not hasattr(tools_module, "Valves"):
          return None
      # pydantic v2 API, consistent with the model_dump() calls this module uses
      # on valves; BaseModel.schema() is its deprecated v1-style alias.
      return tools_module.Valves.model_json_schema()
  ############################
  # UpdateToolValves
  ############################


  @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  async def update_tools_valves_by_id(
      request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  ):
      """Validate and persist new admin-level valve values for a tool.

      Only the owner, users with "write" access, and admins may do this. Keys
      whose value is None are dropped before the rest is validated against the
      module's ``Valves`` model. The validated values are stored and returned
      as a dict.
      """
      tool = Tools.get_tool_by_id(id)
      if not tool:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      may_write = (
          tool.user_id == user.id
          or has_access(user.id, "write", tool.access_control)
          or user.role == "admin"
      )
      if not may_write:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
          )

      loaded_modules = request.app.state.TOOLS
      if id in loaded_modules:
          module = loaded_modules[id]
      else:
          module, _frontmatter = load_tool_module_by_id(id)
          loaded_modules[id] = module

      if not hasattr(module, "Valves"):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )
      valves_model = module.Valves

      try:
          provided = {key: value for key, value in form_data.items() if value is not None}
          validated = valves_model(**provided)
          Tools.update_tool_valves_by_id(id, validated.model_dump())
          return validated.model_dump()
      except Exception as e:
          log.exception(f"Failed to update tool valves by id {id}: {e}")
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          )
  ############################
  # ToolUserValves
  ############################


  @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
      """Return the calling user's own valve values for a tool.

      Raises:
          HTTPException: 401 if the tool does not exist; 400 if the lookup fails.
      """
      if not Tools.get_tool_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      try:
          return Tools.get_user_valves_by_id_and_user_id(id, user.id)
      except Exception as e:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          )
  @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  async def get_tools_user_valves_spec_by_id(
      request: Request, id: str, user=Depends(get_verified_user)
  ):
      """Return the JSON schema of the tool's ``UserValves`` model.

      Returns ``None`` when the tool module defines no ``UserValves``. The
      module is loaded with load_tool_module_by_id and cached in
      ``request.app.state.TOOLS`` if it is not there yet.

      Raises:
          HTTPException: 401 if the tool does not exist.
      """
      tools = Tools.get_tool_by_id(id)
      if not tools:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      if id in request.app.state.TOOLS:
          tools_module = request.app.state.TOOLS[id]
      else:
          tools_module, _ = load_tool_module_by_id(id)
          request.app.state.TOOLS[id] = tools_module

      if not hasattr(tools_module, "UserValves"):
          return None
      # pydantic v2 API, consistent with the model_dump() calls this module uses
      # on valves; BaseModel.schema() is its deprecated v1-style alias.
      return tools_module.UserValves.model_json_schema()
  @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  async def update_tools_user_valves_by_id(
      request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  ):
      """Validate and store the calling user's own valve values for a tool.

      Keys whose value is None are dropped before validation against the tool
      module's ``UserValves`` model. The validated values are saved for this
      (tool id, user id) pair and returned as a dict.

      Raises:
          HTTPException: 401 if the tool is missing or defines no UserValves;
              400 if validation or storage fails.
      """
      if not Tools.get_tool_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      loaded_modules = request.app.state.TOOLS
      if id in loaded_modules:
          module = loaded_modules[id]
      else:
          module, _frontmatter = load_tool_module_by_id(id)
          loaded_modules[id] = module

      if not hasattr(module, "UserValves"):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )
      user_valves_model = module.UserValves

      try:
          provided = {key: value for key, value in form_data.items() if value is not None}
          validated = user_valves_model(**provided)
          Tools.update_user_valves_by_id_and_user_id(
              id, user.id, validated.model_dump()
          )
          return validated.model_dump()
      except Exception as e:
          log.exception(f"Failed to update user valves by id {id}: {e}")
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          )