tools.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. import logging
  2. from pathlib import Path
  3. from typing import Optional
  4. import time
  5. from open_webui.models.tools import (
  6. ToolForm,
  7. ToolModel,
  8. ToolResponse,
  9. ToolUserResponse,
  10. Tools,
  11. )
  12. from open_webui.utils.plugin import load_tools_module_by_id, replace_imports
  13. from open_webui.config import CACHE_DIR
  14. from open_webui.constants import ERROR_MESSAGES
  15. from fastapi import APIRouter, Depends, HTTPException, Request, status
  16. from open_webui.utils.tools import get_tools_specs
  17. from open_webui.utils.auth import get_admin_user, get_verified_user
  18. from open_webui.utils.access_control import has_access, has_permission
  19. from open_webui.env import SRC_LOG_LEVELS
  20. from open_webui.utils.tools import get_tool_servers_data
# Module-level logger; its level is taken from the "MAIN" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])

# Router on which the tool endpoints below are registered.
router = APIRouter()
  24. ############################
  25. # GetTools
  26. ############################
  27. @router.get("/", response_model=list[ToolUserResponse])
  28. async def get_tools(request: Request, user=Depends(get_verified_user)):
  29. if not request.app.state.TOOL_SERVERS:
  30. # If the tool servers are not set, we need to set them
  31. # This is done only once when the server starts
  32. # This is done to avoid loading the tool servers every time
  33. request.app.state.TOOL_SERVERS = await get_tool_servers_data(
  34. request.app.state.config.TOOL_SERVER_CONNECTIONS
  35. )
  36. tools = Tools.get_tools()
  37. for idx, server in enumerate(request.app.state.TOOL_SERVERS):
  38. tools.append(
  39. ToolUserResponse(
  40. **{
  41. "id": f"server:{server['idx']}",
  42. "user_id": f"server:{server['idx']}",
  43. "name": server["openapi"]
  44. .get("info", {})
  45. .get("title", "Tool Server"),
  46. "meta": {
  47. "description": server["openapi"]
  48. .get("info", {})
  49. .get("description", ""),
  50. },
  51. "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
  52. idx
  53. ]
  54. .get("config", {})
  55. .get("access_control", None),
  56. "updated_at": int(time.time()),
  57. "created_at": int(time.time()),
  58. }
  59. )
  60. )
  61. if user.role != "admin":
  62. tools = [
  63. tool
  64. for tool in tools
  65. if tool.user_id == user.id
  66. or has_access(user.id, "read", tool.access_control)
  67. ]
  68. return tools
  69. ############################
  70. # GetToolList
  71. ############################
  72. @router.get("/list", response_model=list[ToolUserResponse])
  73. async def get_tool_list(user=Depends(get_verified_user)):
  74. if user.role == "admin":
  75. tools = Tools.get_tools()
  76. else:
  77. tools = Tools.get_tools_by_user_id(user.id, "write")
  78. return tools
  79. ############################
  80. # ExportTools
  81. ############################
  82. @router.get("/export", response_model=list[ToolModel])
  83. async def export_tools(user=Depends(get_admin_user)):
  84. tools = Tools.get_tools()
  85. return tools
  86. ############################
  87. # CreateNewTools
  88. ############################
  89. @router.post("/create", response_model=Optional[ToolResponse])
  90. async def create_new_tools(
  91. request: Request,
  92. form_data: ToolForm,
  93. user=Depends(get_verified_user),
  94. ):
  95. if user.role != "admin" and not has_permission(
  96. user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
  97. ):
  98. raise HTTPException(
  99. status_code=status.HTTP_401_UNAUTHORIZED,
  100. detail=ERROR_MESSAGES.UNAUTHORIZED,
  101. )
  102. if not form_data.id.isidentifier():
  103. raise HTTPException(
  104. status_code=status.HTTP_400_BAD_REQUEST,
  105. detail="Only alphanumeric characters and underscores are allowed in the id",
  106. )
  107. form_data.id = form_data.id.lower()
  108. tools = Tools.get_tool_by_id(form_data.id)
  109. if tools is None:
  110. try:
  111. form_data.content = replace_imports(form_data.content)
  112. tools_module, frontmatter = load_tools_module_by_id(
  113. form_data.id, content=form_data.content
  114. )
  115. form_data.meta.manifest = frontmatter
  116. TOOLS = request.app.state.TOOLS
  117. TOOLS[form_data.id] = tools_module
  118. specs = get_tools_specs(TOOLS[form_data.id])
  119. tools = Tools.insert_new_tool(user.id, form_data, specs)
  120. tool_cache_dir = CACHE_DIR / "tools" / form_data.id
  121. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  122. if tools:
  123. return tools
  124. else:
  125. raise HTTPException(
  126. status_code=status.HTTP_400_BAD_REQUEST,
  127. detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
  128. )
  129. except Exception as e:
  130. log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
  131. raise HTTPException(
  132. status_code=status.HTTP_400_BAD_REQUEST,
  133. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  134. )
  135. else:
  136. raise HTTPException(
  137. status_code=status.HTTP_400_BAD_REQUEST,
  138. detail=ERROR_MESSAGES.ID_TAKEN,
  139. )
  140. ############################
  141. # GetToolsById
  142. ############################
  143. @router.get("/id/{id}", response_model=Optional[ToolModel])
  144. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  145. tools = Tools.get_tool_by_id(id)
  146. if tools:
  147. if (
  148. user.role == "admin"
  149. or tools.user_id == user.id
  150. or has_access(user.id, "read", tools.access_control)
  151. ):
  152. return tools
  153. else:
  154. raise HTTPException(
  155. status_code=status.HTTP_401_UNAUTHORIZED,
  156. detail=ERROR_MESSAGES.NOT_FOUND,
  157. )
  158. ############################
  159. # UpdateToolsById
  160. ############################
  161. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  162. async def update_tools_by_id(
  163. request: Request,
  164. id: str,
  165. form_data: ToolForm,
  166. user=Depends(get_verified_user),
  167. ):
  168. tools = Tools.get_tool_by_id(id)
  169. if not tools:
  170. raise HTTPException(
  171. status_code=status.HTTP_401_UNAUTHORIZED,
  172. detail=ERROR_MESSAGES.NOT_FOUND,
  173. )
  174. # Is the user the original creator, in a group with write access, or an admin
  175. if (
  176. tools.user_id != user.id
  177. and not has_access(user.id, "write", tools.access_control)
  178. and user.role != "admin"
  179. ):
  180. raise HTTPException(
  181. status_code=status.HTTP_401_UNAUTHORIZED,
  182. detail=ERROR_MESSAGES.UNAUTHORIZED,
  183. )
  184. try:
  185. form_data.content = replace_imports(form_data.content)
  186. tools_module, frontmatter = load_tools_module_by_id(
  187. id, content=form_data.content
  188. )
  189. form_data.meta.manifest = frontmatter
  190. TOOLS = request.app.state.TOOLS
  191. TOOLS[id] = tools_module
  192. specs = get_tools_specs(TOOLS[id])
  193. updated = {
  194. **form_data.model_dump(exclude={"id"}),
  195. "specs": specs,
  196. }
  197. log.debug(updated)
  198. tools = Tools.update_tool_by_id(id, updated)
  199. if tools:
  200. return tools
  201. else:
  202. raise HTTPException(
  203. status_code=status.HTTP_400_BAD_REQUEST,
  204. detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
  205. )
  206. except Exception as e:
  207. raise HTTPException(
  208. status_code=status.HTTP_400_BAD_REQUEST,
  209. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  210. )
  211. ############################
  212. # DeleteToolsById
  213. ############################
  214. @router.delete("/id/{id}/delete", response_model=bool)
  215. async def delete_tools_by_id(
  216. request: Request, id: str, user=Depends(get_verified_user)
  217. ):
  218. tools = Tools.get_tool_by_id(id)
  219. if not tools:
  220. raise HTTPException(
  221. status_code=status.HTTP_401_UNAUTHORIZED,
  222. detail=ERROR_MESSAGES.NOT_FOUND,
  223. )
  224. if (
  225. tools.user_id != user.id
  226. and not has_access(user.id, "write", tools.access_control)
  227. and user.role != "admin"
  228. ):
  229. raise HTTPException(
  230. status_code=status.HTTP_401_UNAUTHORIZED,
  231. detail=ERROR_MESSAGES.UNAUTHORIZED,
  232. )
  233. result = Tools.delete_tool_by_id(id)
  234. if result:
  235. TOOLS = request.app.state.TOOLS
  236. if id in TOOLS:
  237. del TOOLS[id]
  238. return result
  239. ############################
  240. # GetToolValves
  241. ############################
  242. @router.get("/id/{id}/valves", response_model=Optional[dict])
  243. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  244. tools = Tools.get_tool_by_id(id)
  245. if tools:
  246. try:
  247. valves = Tools.get_tool_valves_by_id(id)
  248. return valves
  249. except Exception as e:
  250. raise HTTPException(
  251. status_code=status.HTTP_400_BAD_REQUEST,
  252. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  253. )
  254. else:
  255. raise HTTPException(
  256. status_code=status.HTTP_401_UNAUTHORIZED,
  257. detail=ERROR_MESSAGES.NOT_FOUND,
  258. )
  259. ############################
  260. # GetToolValvesSpec
  261. ############################
  262. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  263. async def get_tools_valves_spec_by_id(
  264. request: Request, id: str, user=Depends(get_verified_user)
  265. ):
  266. tools = Tools.get_tool_by_id(id)
  267. if tools:
  268. if id in request.app.state.TOOLS:
  269. tools_module = request.app.state.TOOLS[id]
  270. else:
  271. tools_module, _ = load_tools_module_by_id(id)
  272. request.app.state.TOOLS[id] = tools_module
  273. if hasattr(tools_module, "Valves"):
  274. Valves = tools_module.Valves
  275. return Valves.schema()
  276. return None
  277. else:
  278. raise HTTPException(
  279. status_code=status.HTTP_401_UNAUTHORIZED,
  280. detail=ERROR_MESSAGES.NOT_FOUND,
  281. )
  282. ############################
  283. # UpdateToolValves
  284. ############################
  285. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  286. async def update_tools_valves_by_id(
  287. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  288. ):
  289. tools = Tools.get_tool_by_id(id)
  290. if not tools:
  291. raise HTTPException(
  292. status_code=status.HTTP_401_UNAUTHORIZED,
  293. detail=ERROR_MESSAGES.NOT_FOUND,
  294. )
  295. if (
  296. tools.user_id != user.id
  297. and not has_access(user.id, "write", tools.access_control)
  298. and user.role != "admin"
  299. ):
  300. raise HTTPException(
  301. status_code=status.HTTP_400_BAD_REQUEST,
  302. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  303. )
  304. if id in request.app.state.TOOLS:
  305. tools_module = request.app.state.TOOLS[id]
  306. else:
  307. tools_module, _ = load_tools_module_by_id(id)
  308. request.app.state.TOOLS[id] = tools_module
  309. if not hasattr(tools_module, "Valves"):
  310. raise HTTPException(
  311. status_code=status.HTTP_401_UNAUTHORIZED,
  312. detail=ERROR_MESSAGES.NOT_FOUND,
  313. )
  314. Valves = tools_module.Valves
  315. try:
  316. form_data = {k: v for k, v in form_data.items() if v is not None}
  317. valves = Valves(**form_data)
  318. Tools.update_tool_valves_by_id(id, valves.model_dump())
  319. return valves.model_dump()
  320. except Exception as e:
  321. log.exception(f"Failed to update tool valves by id {id}: {e}")
  322. raise HTTPException(
  323. status_code=status.HTTP_400_BAD_REQUEST,
  324. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  325. )
  326. ############################
  327. # ToolUserValves
  328. ############################
  329. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  330. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  331. tools = Tools.get_tool_by_id(id)
  332. if tools:
  333. try:
  334. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  335. return user_valves
  336. except Exception as e:
  337. raise HTTPException(
  338. status_code=status.HTTP_400_BAD_REQUEST,
  339. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  340. )
  341. else:
  342. raise HTTPException(
  343. status_code=status.HTTP_401_UNAUTHORIZED,
  344. detail=ERROR_MESSAGES.NOT_FOUND,
  345. )
  346. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  347. async def get_tools_user_valves_spec_by_id(
  348. request: Request, id: str, user=Depends(get_verified_user)
  349. ):
  350. tools = Tools.get_tool_by_id(id)
  351. if tools:
  352. if id in request.app.state.TOOLS:
  353. tools_module = request.app.state.TOOLS[id]
  354. else:
  355. tools_module, _ = load_tools_module_by_id(id)
  356. request.app.state.TOOLS[id] = tools_module
  357. if hasattr(tools_module, "UserValves"):
  358. UserValves = tools_module.UserValves
  359. return UserValves.schema()
  360. return None
  361. else:
  362. raise HTTPException(
  363. status_code=status.HTTP_401_UNAUTHORIZED,
  364. detail=ERROR_MESSAGES.NOT_FOUND,
  365. )
  366. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  367. async def update_tools_user_valves_by_id(
  368. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  369. ):
  370. tools = Tools.get_tool_by_id(id)
  371. if tools:
  372. if id in request.app.state.TOOLS:
  373. tools_module = request.app.state.TOOLS[id]
  374. else:
  375. tools_module, _ = load_tools_module_by_id(id)
  376. request.app.state.TOOLS[id] = tools_module
  377. if hasattr(tools_module, "UserValves"):
  378. UserValves = tools_module.UserValves
  379. try:
  380. form_data = {k: v for k, v in form_data.items() if v is not None}
  381. user_valves = UserValves(**form_data)
  382. Tools.update_user_valves_by_id_and_user_id(
  383. id, user.id, user_valves.model_dump()
  384. )
  385. return user_valves.model_dump()
  386. except Exception as e:
  387. log.exception(f"Failed to update user valves by id {id}: {e}")
  388. raise HTTPException(
  389. status_code=status.HTTP_400_BAD_REQUEST,
  390. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  391. )
  392. else:
  393. raise HTTPException(
  394. status_code=status.HTTP_401_UNAUTHORIZED,
  395. detail=ERROR_MESSAGES.NOT_FOUND,
  396. )
  397. else:
  398. raise HTTPException(
  399. status_code=status.HTTP_401_UNAUTHORIZED,
  400. detail=ERROR_MESSAGES.NOT_FOUND,
  401. )