tools.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. import logging
  2. from pathlib import Path
  3. from typing import Optional
  4. import time
  5. import re
  6. import aiohttp
  7. from pydantic import BaseModel, HttpUrl
  8. from open_webui.models.tools import (
  9. ToolForm,
  10. ToolModel,
  11. ToolResponse,
  12. ToolUserResponse,
  13. Tools,
  14. )
  15. from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
  16. from open_webui.config import CACHE_DIR
  17. from open_webui.constants import ERROR_MESSAGES
  18. from fastapi import APIRouter, Depends, HTTPException, Request, status
  19. from open_webui.utils.tools import get_tool_specs
  20. from open_webui.utils.auth import get_admin_user, get_verified_user
  21. from open_webui.utils.access_control import has_access, has_permission
  22. from open_webui.env import SRC_LOG_LEVELS
  23. from open_webui.utils.tools import get_tool_servers_data
# Module-level logger; its level comes from the "MAIN" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])

# Router on which every endpoint defined in this module is registered.
router = APIRouter()
  27. ############################
  28. # GetTools
  29. ############################
  30. @router.get("/", response_model=list[ToolUserResponse])
  31. async def get_tools(request: Request, user=Depends(get_verified_user)):
  32. if not request.app.state.TOOL_SERVERS:
  33. # If the tool servers are not set, we need to set them
  34. # This is done only once when the server starts
  35. # This is done to avoid loading the tool servers every time
  36. request.app.state.TOOL_SERVERS = await get_tool_servers_data(
  37. request.app.state.config.TOOL_SERVER_CONNECTIONS
  38. )
  39. tools = Tools.get_tools()
  40. for server in request.app.state.TOOL_SERVERS:
  41. tools.append(
  42. ToolUserResponse(
  43. **{
  44. "id": f"server:{server['idx']}",
  45. "user_id": f"server:{server['idx']}",
  46. "name": server.get("openapi", {})
  47. .get("info", {})
  48. .get("title", "Tool Server"),
  49. "meta": {
  50. "description": server.get("openapi", {})
  51. .get("info", {})
  52. .get("description", ""),
  53. },
  54. "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
  55. server["idx"]
  56. ]
  57. .get("config", {})
  58. .get("access_control", None),
  59. "updated_at": int(time.time()),
  60. "created_at": int(time.time()),
  61. }
  62. )
  63. )
  64. if user.role != "admin":
  65. tools = [
  66. tool
  67. for tool in tools
  68. if tool.user_id == user.id
  69. or has_access(user.id, "read", tool.access_control)
  70. ]
  71. return tools
  72. ############################
  73. # GetToolList
  74. ############################
  75. @router.get("/list", response_model=list[ToolUserResponse])
  76. async def get_tool_list(user=Depends(get_verified_user)):
  77. if user.role == "admin":
  78. tools = Tools.get_tools()
  79. else:
  80. tools = Tools.get_tools_by_user_id(user.id, "write")
  81. return tools
  82. ############################
  83. # LoadToolFromLink
  84. ############################
class LoadUrlForm(BaseModel):
    """Request body for the ``/load/url`` endpoint."""

    # Location to download the tool source from; validated as an HTTP(S) URL.
    url: HttpUrl
  87. def github_url_to_raw_url(url: str) -> str:
  88. # Handle 'tree' (folder) URLs (add main.py at the end)
  89. m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
  90. if m1:
  91. org, repo, branch, path = m1.groups()
  92. return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
  93. # Handle 'blob' (file) URLs
  94. m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
  95. if m2:
  96. org, repo, branch, path = m2.groups()
  97. return (
  98. f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
  99. )
  100. # No match; return as-is
  101. return url
  102. @router.post("/load/url", response_model=Optional[dict])
  103. async def load_tool_from_url(
  104. request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
  105. ):
  106. # NOTE: This is NOT a SSRF vulnerability:
  107. # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
  108. # and does NOT accept untrusted user input. Access is enforced by authentication.
  109. url = str(form_data.url)
  110. if not url:
  111. raise HTTPException(status_code=400, detail="Please enter a valid URL")
  112. url = github_url_to_raw_url(url)
  113. url_parts = url.rstrip("/").split("/")
  114. file_name = url_parts[-1]
  115. tool_name = (
  116. file_name[:-3]
  117. if (
  118. file_name.endswith(".py")
  119. and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
  120. )
  121. else url_parts[-2] if len(url_parts) > 1 else "function"
  122. )
  123. try:
  124. async with aiohttp.ClientSession() as session:
  125. async with session.get(
  126. url, headers={"Content-Type": "application/json"}
  127. ) as resp:
  128. if resp.status != 200:
  129. raise HTTPException(
  130. status_code=resp.status, detail="Failed to fetch the tool"
  131. )
  132. data = await resp.text()
  133. if not data:
  134. raise HTTPException(
  135. status_code=400, detail="No data received from the URL"
  136. )
  137. return {
  138. "name": tool_name,
  139. "content": data,
  140. }
  141. except Exception as e:
  142. raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
  143. ############################
  144. # ExportTools
  145. ############################
  146. @router.get("/export", response_model=list[ToolModel])
  147. async def export_tools(user=Depends(get_admin_user)):
  148. tools = Tools.get_tools()
  149. return tools
  150. ############################
  151. # CreateNewTools
  152. ############################
  153. @router.post("/create", response_model=Optional[ToolResponse])
  154. async def create_new_tools(
  155. request: Request,
  156. form_data: ToolForm,
  157. user=Depends(get_verified_user),
  158. ):
  159. if user.role != "admin" and not has_permission(
  160. user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
  161. ):
  162. raise HTTPException(
  163. status_code=status.HTTP_401_UNAUTHORIZED,
  164. detail=ERROR_MESSAGES.UNAUTHORIZED,
  165. )
  166. if not form_data.id.isidentifier():
  167. raise HTTPException(
  168. status_code=status.HTTP_400_BAD_REQUEST,
  169. detail="Only alphanumeric characters and underscores are allowed in the id",
  170. )
  171. form_data.id = form_data.id.lower()
  172. tools = Tools.get_tool_by_id(form_data.id)
  173. if tools is None:
  174. try:
  175. form_data.content = replace_imports(form_data.content)
  176. tool_module, frontmatter = load_tool_module_by_id(
  177. form_data.id, content=form_data.content
  178. )
  179. form_data.meta.manifest = frontmatter
  180. TOOLS = request.app.state.TOOLS
  181. TOOLS[form_data.id] = tool_module
  182. specs = get_tool_specs(TOOLS[form_data.id])
  183. tools = Tools.insert_new_tool(user.id, form_data, specs)
  184. tool_cache_dir = CACHE_DIR / "tools" / form_data.id
  185. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  186. if tools:
  187. return tools
  188. else:
  189. raise HTTPException(
  190. status_code=status.HTTP_400_BAD_REQUEST,
  191. detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
  192. )
  193. except Exception as e:
  194. log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
  195. raise HTTPException(
  196. status_code=status.HTTP_400_BAD_REQUEST,
  197. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  198. )
  199. else:
  200. raise HTTPException(
  201. status_code=status.HTTP_400_BAD_REQUEST,
  202. detail=ERROR_MESSAGES.ID_TAKEN,
  203. )
  204. ############################
  205. # GetToolsById
  206. ############################
  207. @router.get("/id/{id}", response_model=Optional[ToolModel])
  208. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  209. tools = Tools.get_tool_by_id(id)
  210. if tools:
  211. if (
  212. user.role == "admin"
  213. or tools.user_id == user.id
  214. or has_access(user.id, "read", tools.access_control)
  215. ):
  216. return tools
  217. else:
  218. raise HTTPException(
  219. status_code=status.HTTP_401_UNAUTHORIZED,
  220. detail=ERROR_MESSAGES.NOT_FOUND,
  221. )
  222. ############################
  223. # UpdateToolsById
  224. ############################
  225. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  226. async def update_tools_by_id(
  227. request: Request,
  228. id: str,
  229. form_data: ToolForm,
  230. user=Depends(get_verified_user),
  231. ):
  232. tools = Tools.get_tool_by_id(id)
  233. if not tools:
  234. raise HTTPException(
  235. status_code=status.HTTP_401_UNAUTHORIZED,
  236. detail=ERROR_MESSAGES.NOT_FOUND,
  237. )
  238. # Is the user the original creator, in a group with write access, or an admin
  239. if (
  240. tools.user_id != user.id
  241. and not has_access(user.id, "write", tools.access_control)
  242. and user.role != "admin"
  243. ):
  244. raise HTTPException(
  245. status_code=status.HTTP_401_UNAUTHORIZED,
  246. detail=ERROR_MESSAGES.UNAUTHORIZED,
  247. )
  248. try:
  249. form_data.content = replace_imports(form_data.content)
  250. tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
  251. form_data.meta.manifest = frontmatter
  252. TOOLS = request.app.state.TOOLS
  253. TOOLS[id] = tool_module
  254. specs = get_tool_specs(TOOLS[id])
  255. updated = {
  256. **form_data.model_dump(exclude={"id"}),
  257. "specs": specs,
  258. }
  259. log.debug(updated)
  260. tools = Tools.update_tool_by_id(id, updated)
  261. if tools:
  262. return tools
  263. else:
  264. raise HTTPException(
  265. status_code=status.HTTP_400_BAD_REQUEST,
  266. detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
  267. )
  268. except Exception as e:
  269. raise HTTPException(
  270. status_code=status.HTTP_400_BAD_REQUEST,
  271. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  272. )
  273. ############################
  274. # DeleteToolsById
  275. ############################
  276. @router.delete("/id/{id}/delete", response_model=bool)
  277. async def delete_tools_by_id(
  278. request: Request, id: str, user=Depends(get_verified_user)
  279. ):
  280. tools = Tools.get_tool_by_id(id)
  281. if not tools:
  282. raise HTTPException(
  283. status_code=status.HTTP_401_UNAUTHORIZED,
  284. detail=ERROR_MESSAGES.NOT_FOUND,
  285. )
  286. if (
  287. tools.user_id != user.id
  288. and not has_access(user.id, "write", tools.access_control)
  289. and user.role != "admin"
  290. ):
  291. raise HTTPException(
  292. status_code=status.HTTP_401_UNAUTHORIZED,
  293. detail=ERROR_MESSAGES.UNAUTHORIZED,
  294. )
  295. result = Tools.delete_tool_by_id(id)
  296. if result:
  297. TOOLS = request.app.state.TOOLS
  298. if id in TOOLS:
  299. del TOOLS[id]
  300. return result
  301. ############################
  302. # GetToolValves
  303. ############################
  304. @router.get("/id/{id}/valves", response_model=Optional[dict])
  305. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  306. tools = Tools.get_tool_by_id(id)
  307. if tools:
  308. try:
  309. valves = Tools.get_tool_valves_by_id(id)
  310. return valves
  311. except Exception as e:
  312. raise HTTPException(
  313. status_code=status.HTTP_400_BAD_REQUEST,
  314. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  315. )
  316. else:
  317. raise HTTPException(
  318. status_code=status.HTTP_401_UNAUTHORIZED,
  319. detail=ERROR_MESSAGES.NOT_FOUND,
  320. )
  321. ############################
  322. # GetToolValvesSpec
  323. ############################
  324. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  325. async def get_tools_valves_spec_by_id(
  326. request: Request, id: str, user=Depends(get_verified_user)
  327. ):
  328. tools = Tools.get_tool_by_id(id)
  329. if tools:
  330. if id in request.app.state.TOOLS:
  331. tools_module = request.app.state.TOOLS[id]
  332. else:
  333. tools_module, _ = load_tool_module_by_id(id)
  334. request.app.state.TOOLS[id] = tools_module
  335. if hasattr(tools_module, "Valves"):
  336. Valves = tools_module.Valves
  337. return Valves.schema()
  338. return None
  339. else:
  340. raise HTTPException(
  341. status_code=status.HTTP_401_UNAUTHORIZED,
  342. detail=ERROR_MESSAGES.NOT_FOUND,
  343. )
  344. ############################
  345. # UpdateToolValves
  346. ############################
  347. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  348. async def update_tools_valves_by_id(
  349. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  350. ):
  351. tools = Tools.get_tool_by_id(id)
  352. if not tools:
  353. raise HTTPException(
  354. status_code=status.HTTP_401_UNAUTHORIZED,
  355. detail=ERROR_MESSAGES.NOT_FOUND,
  356. )
  357. if (
  358. tools.user_id != user.id
  359. and not has_access(user.id, "write", tools.access_control)
  360. and user.role != "admin"
  361. ):
  362. raise HTTPException(
  363. status_code=status.HTTP_400_BAD_REQUEST,
  364. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  365. )
  366. if id in request.app.state.TOOLS:
  367. tools_module = request.app.state.TOOLS[id]
  368. else:
  369. tools_module, _ = load_tool_module_by_id(id)
  370. request.app.state.TOOLS[id] = tools_module
  371. if not hasattr(tools_module, "Valves"):
  372. raise HTTPException(
  373. status_code=status.HTTP_401_UNAUTHORIZED,
  374. detail=ERROR_MESSAGES.NOT_FOUND,
  375. )
  376. Valves = tools_module.Valves
  377. try:
  378. form_data = {k: v for k, v in form_data.items() if v is not None}
  379. valves = Valves(**form_data)
  380. Tools.update_tool_valves_by_id(id, valves.model_dump())
  381. return valves.model_dump()
  382. except Exception as e:
  383. log.exception(f"Failed to update tool valves by id {id}: {e}")
  384. raise HTTPException(
  385. status_code=status.HTTP_400_BAD_REQUEST,
  386. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  387. )
  388. ############################
  389. # ToolUserValves
  390. ############################
  391. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  392. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  393. tools = Tools.get_tool_by_id(id)
  394. if tools:
  395. try:
  396. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  397. return user_valves
  398. except Exception as e:
  399. raise HTTPException(
  400. status_code=status.HTTP_400_BAD_REQUEST,
  401. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  402. )
  403. else:
  404. raise HTTPException(
  405. status_code=status.HTTP_401_UNAUTHORIZED,
  406. detail=ERROR_MESSAGES.NOT_FOUND,
  407. )
  408. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  409. async def get_tools_user_valves_spec_by_id(
  410. request: Request, id: str, user=Depends(get_verified_user)
  411. ):
  412. tools = Tools.get_tool_by_id(id)
  413. if tools:
  414. if id in request.app.state.TOOLS:
  415. tools_module = request.app.state.TOOLS[id]
  416. else:
  417. tools_module, _ = load_tool_module_by_id(id)
  418. request.app.state.TOOLS[id] = tools_module
  419. if hasattr(tools_module, "UserValves"):
  420. UserValves = tools_module.UserValves
  421. return UserValves.schema()
  422. return None
  423. else:
  424. raise HTTPException(
  425. status_code=status.HTTP_401_UNAUTHORIZED,
  426. detail=ERROR_MESSAGES.NOT_FOUND,
  427. )
  428. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  429. async def update_tools_user_valves_by_id(
  430. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  431. ):
  432. tools = Tools.get_tool_by_id(id)
  433. if tools:
  434. if id in request.app.state.TOOLS:
  435. tools_module = request.app.state.TOOLS[id]
  436. else:
  437. tools_module, _ = load_tool_module_by_id(id)
  438. request.app.state.TOOLS[id] = tools_module
  439. if hasattr(tools_module, "UserValves"):
  440. UserValves = tools_module.UserValves
  441. try:
  442. form_data = {k: v for k, v in form_data.items() if v is not None}
  443. user_valves = UserValves(**form_data)
  444. Tools.update_user_valves_by_id_and_user_id(
  445. id, user.id, user_valves.model_dump()
  446. )
  447. return user_valves.model_dump()
  448. except Exception as e:
  449. log.exception(f"Failed to update user valves by id {id}: {e}")
  450. raise HTTPException(
  451. status_code=status.HTTP_400_BAD_REQUEST,
  452. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  453. )
  454. else:
  455. raise HTTPException(
  456. status_code=status.HTTP_401_UNAUTHORIZED,
  457. detail=ERROR_MESSAGES.NOT_FOUND,
  458. )
  459. else:
  460. raise HTTPException(
  461. status_code=status.HTTP_401_UNAUTHORIZED,
  462. detail=ERROR_MESSAGES.NOT_FOUND,
  463. )