1
0

tools.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. import logging
  2. from pathlib import Path
  3. from typing import Optional
  4. import time
  5. import re
  6. import aiohttp
  7. from pydantic import BaseModel, HttpUrl
  8. from fastapi import APIRouter, Depends, HTTPException, Request, status
  9. from open_webui.models.tools import (
  10. ToolForm,
  11. ToolModel,
  12. ToolResponse,
  13. ToolUserResponse,
  14. Tools,
  15. )
  16. from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
  17. from open_webui.utils.tools import get_tool_specs
  18. from open_webui.utils.auth import get_admin_user, get_verified_user
  19. from open_webui.utils.access_control import has_access, has_permission
  20. from open_webui.utils.tools import get_tool_servers
  21. from open_webui.env import SRC_LOG_LEVELS
  22. from open_webui.config import CACHE_DIR, ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
  23. from open_webui.constants import ERROR_MESSAGES
  # Router that every tool endpoint in this module registers itself on.
  router = APIRouter()

  # Module logger; its level comes from the "MAIN" entry of SRC_LOG_LEVELS.
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["MAIN"])
  27. ############################
  28. # GetTools
  29. ############################
  @router.get("/", response_model=list[ToolUserResponse])
  async def get_tools(request: Request, user=Depends(get_verified_user)):
      # Lists the stored tools plus one synthetic entry per tool server returned
      # by get_tool_servers(), then narrows the list to what the caller may read.
      # (Kept as comments: FastAPI would publish a docstring in the OpenAPI docs.)
      collected = Tools.get_tools()

      for server in await get_tool_servers(request):
          server_key = f"server:{server.get('id')}"
          info = server.get("openapi", {}).get("info", {})

          # The access rules of a server live in its connection config, looked
          # up by the server's "idx" (falling back to the first connection).
          connection = request.app.state.config.TOOL_SERVER_CONNECTIONS[
              server.get("idx", 0)
          ]
          access_control = connection.get("config", {}).get("access_control", None)

          collected.append(
              ToolUserResponse(
                  id=server_key,
                  user_id=server_key,
                  name=info.get("title", "Tool Server"),
                  meta={"description": info.get("description", "")},
                  access_control=access_control,
                  # Server entries have no stored timestamps; "now" is used.
                  updated_at=int(time.time()),
                  created_at=int(time.time()),
              )
          )

      # Admins see everything when admin workspace access is enabled.
      if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
          return collected

      # Everyone else: own tools, or tools has_access() grants "read" on.
      return [
          entry
          for entry in collected
          if entry.user_id == user.id
          or has_access(user.id, "read", entry.access_control)
      ]
  68. ############################
  69. # GetToolList
  70. ############################
  @router.get("/list", response_model=list[ToolUserResponse])
  async def get_tool_list(user=Depends(get_verified_user)):
      # Admins (when admin workspace access is enabled) get every stored tool;
      # other users get what Tools.get_tools_by_user_id returns for "write".
      sees_everything = user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
      if sees_everything:
          return Tools.get_tools()
      return Tools.get_tools_by_user_id(user.id, "write")
  78. ############################
  79. # LoadFunctionFromLink
  80. ############################
  class LoadUrlForm(BaseModel):
      # Request body of the "/load/url" endpoint.
      # `url` is validated by pydantic as an HTTP(S) URL before the handler runs.
      url: HttpUrl
  def github_url_to_raw_url(url: str) -> str:
      """Translate a github.com ``tree``/``blob`` URL into a raw-content URL.

      * ``.../tree/<branch>[/<folder>]`` maps to ``main.py`` inside that folder.
        The folder part is optional, so a link to the branch root (with or
        without a trailing slash) resolves to ``<branch>/main.py``.
      * ``.../blob/<branch>/<file>`` maps to that file.
      * Anything else is returned unchanged.

      Args:
          url: The URL to convert.

      Returns:
          The ``raw.githubusercontent.com`` URL, or ``url`` itself when it is
          not a recognised GitHub tree/blob link.
      """
      raw_base = "https://raw.githubusercontent.com"
      # NOTE(review): "refs/heads/" treats the ref segment as a branch name;
      # links pinned to a tag or commit may need a different form - confirm.

      # Folder ("tree") URL: append main.py to the folder path.
      tree = re.match(
          r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)(?:/(.*))?", url
      )
      if tree:
          org, repo, branch, path = tree.groups()
          folder = (path or "").rstrip("/")
          # An empty folder means the branch root; joining it blindly would
          # produce a "//main.py" path.
          target = f"{folder}/main.py" if folder else "main.py"
          return f"{raw_base}/{org}/{repo}/refs/heads/{branch}/{target}"

      # File ("blob") URL: same path, served from the raw host.
      blob = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
      if blob:
          org, repo, branch, path = blob.groups()
          return f"{raw_base}/{org}/{repo}/refs/heads/{branch}/{path}"

      # No match; return as-is
      return url
  @router.post("/load/url", response_model=Optional[dict])
  async def load_tool_from_url(
      request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
  ):
      # Downloads tool source code from a URL and returns it together with a
      # suggested tool name: {"name": ..., "content": ...}.
      #
      # Raises HTTPException:
      #   - the upstream status code when the fetch does not answer 200,
      #   - 400 when the URL is empty or the response body is empty,
      #   - 500 for any other failure (connection errors, decoding, ...).
      #
      # NOTE: This is NOT a SSRF vulnerability:
      # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
      # and does NOT accept untrusted user input. Access is enforced by authentication.
      url = str(form_data.url)
      if not url:
          raise HTTPException(status_code=400, detail="Please enter a valid URL")

      url = github_url_to_raw_url(url)
      url_parts = url.rstrip("/").split("/")
      file_name = url_parts[-1]

      # Name the tool after the file, unless the file has a generic entry-point
      # name; then use the containing path segment instead.
      generic_names = ("main.py", "index.py", "__init__.py")
      if file_name.endswith(".py") and not file_name.startswith(generic_names):
          tool_name = file_name[:-3]
      elif len(url_parts) > 1:
          tool_name = url_parts[-2]
      else:
          tool_name = "function"

      try:
          async with aiohttp.ClientSession(trust_env=True) as session:
              async with session.get(
                  url, headers={"Content-Type": "application/json"}
              ) as resp:
                  if resp.status != 200:
                      raise HTTPException(
                          status_code=resp.status, detail="Failed to fetch the tool"
                      )
                  data = await resp.text()
                  if not data:
                      raise HTTPException(
                          status_code=400, detail="No data received from the URL"
                      )
                  return {
                      "name": tool_name,
                      "content": data,
                  }
      except HTTPException:
          # Deliberate HTTP errors raised above must keep their status code
          # instead of being rewrapped as a generic 500 below.
          raise
      except Exception as e:
          raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
  139. ############################
  140. # ExportTools
  141. ############################
  @router.get("/export", response_model=list[ToolModel])
  async def export_tools(user=Depends(get_admin_user)):
      # Returns every stored tool; the get_admin_user dependency guards access.
      return Tools.get_tools()
  146. ############################
  147. # CreateNewTools
  148. ############################
  @router.post("/create", response_model=Optional[ToolResponse])
  async def create_new_tools(
      request: Request,
      form_data: ToolForm,
      user=Depends(get_verified_user),
  ):
      # Creates a tool from submitted source code: loads the module, caches it
      # in app.state.TOOLS, stores the record and creates its cache directory.
      #
      # Raises HTTPException:
      #   - 401 when a non-admin lacks the "workspace.tools" permission,
      #   - 400 when the id is not a valid identifier, is already taken, or
      #     loading/storing the tool fails.
      if user.role != "admin" and not has_permission(
          user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
      ):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      if not form_data.id.isidentifier():
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Only alphanumeric characters and underscores are allowed in the id",
          )

      # Ids are stored lower-cased.
      form_data.id = form_data.id.lower()

      if Tools.get_tool_by_id(form_data.id) is not None:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.ID_TAKEN,
          )

      try:
          form_data.content = replace_imports(form_data.content)
          tool_module, frontmatter = load_tool_module_by_id(
              form_data.id, content=form_data.content
          )
          form_data.meta.manifest = frontmatter

          TOOLS = request.app.state.TOOLS
          TOOLS[form_data.id] = tool_module

          specs = get_tool_specs(TOOLS[form_data.id])
          tools = Tools.insert_new_tool(user.id, form_data, specs)

          tool_cache_dir = CACHE_DIR / "tools" / form_data.id
          tool_cache_dir.mkdir(parents=True, exist_ok=True)

          if tools:
              return tools
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
          )
      except HTTPException:
          # Already a well-formed API error; do not log it as a crash or wrap
          # its message a second time.
          raise
      except Exception as e:
          log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          )
  200. ############################
  201. # GetToolsById
  202. ############################
  @router.get("/id/{id}", response_model=Optional[ToolModel])
  async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
      # Returns one stored tool to an admin, its owner, or a user that
      # has_access() grants "read" on it.
      #
      # Raises HTTPException 401 when the tool does not exist (NOT_FOUND) or the
      # caller may not read it (UNAUTHORIZED). Previously one of these paths
      # fell through and answered 200 with a null body.
      tools = Tools.get_tool_by_id(id)
      if not tools:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      may_read = (
          user.role == "admin"
          or tools.user_id == user.id
          or has_access(user.id, "read", tools.access_control)
      )
      if not may_read:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      return tools
  218. ############################
  219. # UpdateToolsById
  220. ############################
  @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  async def update_tools_by_id(
      request: Request,
      id: str,
      form_data: ToolForm,
      user=Depends(get_verified_user),
  ):
      # Replaces the source and metadata of an existing tool: reloads the
      # module, refreshes the app.state.TOOLS cache and stores the new record.
      #
      # Raises HTTPException:
      #   - 401 when the tool does not exist or the caller may not modify it,
      #   - 400 when loading or storing the updated tool fails.
      tools = Tools.get_tool_by_id(id)
      if not tools:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      # Is the user the original creator, in a group with write access, or an admin
      if (
          tools.user_id != user.id
          and not has_access(user.id, "write", tools.access_control)
          and user.role != "admin"
      ):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      try:
          form_data.content = replace_imports(form_data.content)
          tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
          form_data.meta.manifest = frontmatter

          TOOLS = request.app.state.TOOLS
          TOOLS[id] = tool_module

          specs = get_tool_specs(TOOLS[id])

          # The id comes from the path and is never overwritten by the form.
          updated = {
              **form_data.model_dump(exclude={"id"}),
              "specs": specs,
          }
          log.debug(updated)

          tools = Tools.update_tool_by_id(id, updated)
          if tools:
              return tools
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
          )
      except HTTPException:
          # Already a well-formed API error; pass it through untouched instead
          # of wrapping its message a second time.
          raise
      except Exception as e:
          # Logged like the create endpoint so failed updates leave a trace.
          log.exception(f"Failed to update the tool by id {id}: {e}")
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          )
  269. ############################
  270. # DeleteToolsById
  271. ############################
  @router.delete("/id/{id}/delete", response_model=bool)
  async def delete_tools_by_id(
      request: Request, id: str, user=Depends(get_verified_user)
  ):
      # Deletes a stored tool and, on success, evicts its loaded module from
      # app.state.TOOLS. Returns the result of Tools.delete_tool_by_id.
      # Raises HTTPException 401 for an unknown tool or a caller who is neither
      # owner, write-enabled via has_access(), nor admin.
      target = Tools.get_tool_by_id(id)
      if not target:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      may_delete = (
          target.user_id == user.id
          or has_access(user.id, "write", target.access_control)
          or user.role == "admin"
      )
      if not may_delete:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      deleted = Tools.delete_tool_by_id(id)
      # Only touch the module cache once the record is really gone.
      if deleted and id in request.app.state.TOOLS:
          del request.app.state.TOOLS[id]

      return deleted
  297. ############################
  298. # GetToolValves
  299. ############################
  @router.get("/id/{id}/valves", response_model=Optional[dict])
  async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
      # Returns the stored valves of a tool.
      #
      # Access is limited to the owner, users that has_access() grants "write"
      # on the tool, and admins - the same rule the valves *update* endpoint
      # applies. Before, any verified user could read them.
      # NOTE(review): presumably valves can carry credentials, which is why the
      # read is guarded like the write - confirm with the tool authors.
      #
      # Raises HTTPException:
      #   - 401 when the tool does not exist or the caller is not allowed,
      #   - 400 when reading the valves fails.
      tools = Tools.get_tool_by_id(id)
      if not tools:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      may_read_valves = (
          tools.user_id == user.id
          or has_access(user.id, "write", tools.access_control)
          or user.role == "admin"
      )
      if not may_read_valves:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      try:
          return Tools.get_tool_valves_by_id(id)
      except Exception as e:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          )
  317. ############################
  318. # GetToolValvesSpec
  319. ############################
  @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  async def get_tools_valves_spec_by_id(
      request: Request, id: str, user=Depends(get_verified_user)
  ):
      # Returns the schema of the tool module's `Valves` class, or None when the
      # module defines no such class. Raises HTTPException 401 (NOT_FOUND) for
      # an unknown tool.
      if not Tools.get_tool_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      # Load the module on first use and keep it in the app-wide cache.
      loaded = request.app.state.TOOLS
      if id not in loaded:
          module, _ = load_tool_module_by_id(id)
          loaded[id] = module
      tools_module = loaded[id]

      if not hasattr(tools_module, "Valves"):
          return None

      # NOTE(review): `.schema()` looks like the pydantic v1-style call;
      # `model_json_schema()` may be the current spelling - confirm.
      return tools_module.Valves.schema()
  340. ############################
  341. # UpdateToolValves
  342. ############################
  @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  async def update_tools_valves_by_id(
      request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  ):
      # Validates the submitted values against the tool module's `Valves` class,
      # stores them and returns the validated values as a dict.
      #
      # Raises HTTPException:
      #   - 401 (NOT_FOUND) for an unknown tool or a module without `Valves`,
      #   - 400 (ACCESS_PROHIBITED) when the caller is neither owner,
      #     write-enabled via has_access(), nor admin,
      #   - 400 when validation or storing fails.
      target = Tools.get_tool_by_id(id)
      if not target:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      may_write = (
          target.user_id == user.id
          or has_access(user.id, "write", target.access_control)
          or user.role == "admin"
      )
      if not may_write:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
          )

      # Load the module on first use and keep it in the app-wide cache.
      loaded = request.app.state.TOOLS
      if id not in loaded:
          module, _ = load_tool_module_by_id(id)
          loaded[id] = module
      tools_module = loaded[id]

      if not hasattr(tools_module, "Valves"):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )
      valves_cls = tools_module.Valves

      try:
          # Entries sent as None are dropped before constructing the model.
          supplied = {key: value for key, value in form_data.items() if value is not None}
          valves = valves_cls(**supplied)
          Tools.update_tool_valves_by_id(id, valves.model_dump())
          return valves.model_dump()
      except Exception as e:
          log.exception(f"Failed to update tool valves by id {id}: {e}")
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          )
  384. ############################
  385. # ToolUserValves
  386. ############################
  @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
      # Returns the calling user's own valves for a tool.
      # Raises HTTPException 401 (NOT_FOUND) for an unknown tool and 400 when
      # the lookup fails.
      if not Tools.get_tool_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      try:
          return Tools.get_user_valves_by_id_and_user_id(id, user.id)
      except Exception as e:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          )
  @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  async def get_tools_user_valves_spec_by_id(
      request: Request, id: str, user=Depends(get_verified_user)
  ):
      # Returns the schema of the tool module's `UserValves` class, or None when
      # the module defines no such class. Raises HTTPException 401 (NOT_FOUND)
      # for an unknown tool.
      if not Tools.get_tool_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      # Load the module on first use and keep it in the app-wide cache.
      loaded = request.app.state.TOOLS
      if id not in loaded:
          module, _ = load_tool_module_by_id(id)
          loaded[id] = module
      tools_module = loaded[id]

      if not hasattr(tools_module, "UserValves"):
          return None

      # NOTE(review): `.schema()` looks like the pydantic v1-style call;
      # `model_json_schema()` may be the current spelling - confirm.
      return tools_module.UserValves.schema()
  @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  async def update_tools_user_valves_by_id(
      request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  ):
      # Validates the submitted values against the tool module's `UserValves`
      # class, stores them for the calling user and returns them as a dict.
      #
      # Raises HTTPException:
      #   - 401 (NOT_FOUND) for an unknown tool or a module without `UserValves`,
      #   - 400 when validation or storing fails.
      if not Tools.get_tool_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      # Load the module on first use and keep it in the app-wide cache.
      loaded = request.app.state.TOOLS
      if id not in loaded:
          module, _ = load_tool_module_by_id(id)
          loaded[id] = module
      tools_module = loaded[id]

      if not hasattr(tools_module, "UserValves"):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )
      user_valves_cls = tools_module.UserValves

      try:
          # Entries sent as None are dropped before constructing the model.
          supplied = {key: value for key, value in form_data.items() if value is not None}
          user_valves = user_valves_cls(**supplied)
          Tools.update_user_valves_by_id_and_user_id(
              id, user.id, user_valves.model_dump()
          )
          return user_valves.model_dump()
      except Exception as e:
          log.exception(f"Failed to update user valves by id {id}: {e}")
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          )
  459. )