tools.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. import logging
  2. from pathlib import Path
  3. from typing import Optional
  4. import time
  5. from open_webui.models.tools import (
  6. ToolForm,
  7. ToolModel,
  8. ToolResponse,
  9. ToolUserResponse,
  10. Tools,
  11. )
  12. from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
  13. from open_webui.config import CACHE_DIR
  14. from open_webui.constants import ERROR_MESSAGES
  15. from fastapi import APIRouter, Depends, HTTPException, Request, status
  16. from open_webui.utils.tools import get_tool_specs
  17. from open_webui.utils.auth import get_admin_user, get_verified_user
  18. from open_webui.utils.access_control import has_access, has_permission
  19. from open_webui.env import SRC_LOG_LEVELS
  20. from open_webui.utils.tools import get_tool_servers_data
  21. log = logging.getLogger(__name__)
  22. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  23. router = APIRouter()
  24. ############################
  25. # GetTools
  26. ############################
  27. @router.get("/", response_model=list[ToolUserResponse])
  28. async def get_tools(request: Request, user=Depends(get_verified_user)):
  29. if not request.app.state.TOOL_SERVERS:
  30. # If the tool servers are not set, we need to set them
  31. # This is done only once when the server starts
  32. # This is done to avoid loading the tool servers every time
  33. request.app.state.TOOL_SERVERS = await get_tool_servers_data(
  34. request.app.state.config.TOOL_SERVER_CONNECTIONS
  35. )
  36. tools = Tools.get_tools()
  37. for server in request.app.state.TOOL_SERVERS:
  38. tools.append(
  39. ToolUserResponse(
  40. **{
  41. "id": f"server:{server['idx']}",
  42. "user_id": f"server:{server['idx']}",
  43. "name": server["openapi"]
  44. .get("info", {})
  45. .get("title", "Tool Server"),
  46. "meta": {
  47. "description": server["openapi"]
  48. .get("info", {})
  49. .get("description", ""),
  50. },
  51. "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
  52. server["idx"]
  53. ]
  54. .get("config", {})
  55. .get("access_control", None),
  56. "updated_at": int(time.time()),
  57. "created_at": int(time.time()),
  58. }
  59. )
  60. )
  61. if user.role != "admin":
  62. tools = [
  63. tool
  64. for tool in tools
  65. if tool.user_id == user.id
  66. or has_access(user.id, "read", tool.access_control)
  67. ]
  68. return tools
  69. ############################
  70. # GetToolList
  71. ############################
  72. @router.get("/list", response_model=list[ToolUserResponse])
  73. async def get_tool_list(user=Depends(get_verified_user)):
  74. if user.role == "admin":
  75. tools = Tools.get_tools()
  76. else:
  77. tools = Tools.get_tools_by_user_id(user.id, "write")
  78. return tools
  79. ############################
  80. # ExportTools
  81. ############################
  82. @router.get("/export", response_model=list[ToolModel])
  83. async def export_tools(user=Depends(get_admin_user)):
  84. tools = Tools.get_tools()
  85. return tools
  86. ############################
  87. # CreateNewTools
  88. ############################
  89. @router.post("/create", response_model=Optional[ToolResponse])
  90. async def create_new_tools(
  91. request: Request,
  92. form_data: ToolForm,
  93. user=Depends(get_verified_user),
  94. ):
  95. if user.role != "admin" and not has_permission(
  96. user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
  97. ):
  98. raise HTTPException(
  99. status_code=status.HTTP_401_UNAUTHORIZED,
  100. detail=ERROR_MESSAGES.UNAUTHORIZED,
  101. )
  102. if not form_data.id.isidentifier():
  103. raise HTTPException(
  104. status_code=status.HTTP_400_BAD_REQUEST,
  105. detail="Only alphanumeric characters and underscores are allowed in the id",
  106. )
  107. form_data.id = form_data.id.lower()
  108. tools = Tools.get_tool_by_id(form_data.id)
  109. if tools is None:
  110. try:
  111. form_data.content = replace_imports(form_data.content)
  112. tool_module, frontmatter = load_tool_module_by_id(
  113. form_data.id, content=form_data.content
  114. )
  115. form_data.meta.manifest = frontmatter
  116. TOOLS = request.app.state.TOOLS
  117. TOOLS[form_data.id] = tool_module
  118. specs = get_tool_specs(TOOLS[form_data.id])
  119. tools = Tools.insert_new_tool(user.id, form_data, specs)
  120. tool_cache_dir = CACHE_DIR / "tools" / form_data.id
  121. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  122. if tools:
  123. return tools
  124. else:
  125. raise HTTPException(
  126. status_code=status.HTTP_400_BAD_REQUEST,
  127. detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
  128. )
  129. except Exception as e:
  130. log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
  131. raise HTTPException(
  132. status_code=status.HTTP_400_BAD_REQUEST,
  133. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  134. )
  135. else:
  136. raise HTTPException(
  137. status_code=status.HTTP_400_BAD_REQUEST,
  138. detail=ERROR_MESSAGES.ID_TAKEN,
  139. )
  140. ############################
  141. # GetToolsById
  142. ############################
  143. @router.get("/id/{id}", response_model=Optional[ToolModel])
  144. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  145. tools = Tools.get_tool_by_id(id)
  146. if tools:
  147. if (
  148. user.role == "admin"
  149. or tools.user_id == user.id
  150. or has_access(user.id, "read", tools.access_control)
  151. ):
  152. return tools
  153. else:
  154. raise HTTPException(
  155. status_code=status.HTTP_401_UNAUTHORIZED,
  156. detail=ERROR_MESSAGES.NOT_FOUND,
  157. )
  158. ############################
  159. # UpdateToolsById
  160. ############################
  161. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  162. async def update_tools_by_id(
  163. request: Request,
  164. id: str,
  165. form_data: ToolForm,
  166. user=Depends(get_verified_user),
  167. ):
  168. tools = Tools.get_tool_by_id(id)
  169. if not tools:
  170. raise HTTPException(
  171. status_code=status.HTTP_401_UNAUTHORIZED,
  172. detail=ERROR_MESSAGES.NOT_FOUND,
  173. )
  174. # Is the user the original creator, in a group with write access, or an admin
  175. if (
  176. tools.user_id != user.id
  177. and not has_access(user.id, "write", tools.access_control)
  178. and user.role != "admin"
  179. ):
  180. raise HTTPException(
  181. status_code=status.HTTP_401_UNAUTHORIZED,
  182. detail=ERROR_MESSAGES.UNAUTHORIZED,
  183. )
  184. try:
  185. form_data.content = replace_imports(form_data.content)
  186. tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
  187. form_data.meta.manifest = frontmatter
  188. TOOLS = request.app.state.TOOLS
  189. TOOLS[id] = tool_module
  190. specs = get_tool_specs(TOOLS[id])
  191. updated = {
  192. **form_data.model_dump(exclude={"id"}),
  193. "specs": specs,
  194. }
  195. log.debug(updated)
  196. tools = Tools.update_tool_by_id(id, updated)
  197. if tools:
  198. return tools
  199. else:
  200. raise HTTPException(
  201. status_code=status.HTTP_400_BAD_REQUEST,
  202. detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
  203. )
  204. except Exception as e:
  205. raise HTTPException(
  206. status_code=status.HTTP_400_BAD_REQUEST,
  207. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  208. )
  209. ############################
  210. # DeleteToolsById
  211. ############################
  212. @router.delete("/id/{id}/delete", response_model=bool)
  213. async def delete_tools_by_id(
  214. request: Request, id: str, user=Depends(get_verified_user)
  215. ):
  216. tools = Tools.get_tool_by_id(id)
  217. if not tools:
  218. raise HTTPException(
  219. status_code=status.HTTP_401_UNAUTHORIZED,
  220. detail=ERROR_MESSAGES.NOT_FOUND,
  221. )
  222. if (
  223. tools.user_id != user.id
  224. and not has_access(user.id, "write", tools.access_control)
  225. and user.role != "admin"
  226. ):
  227. raise HTTPException(
  228. status_code=status.HTTP_401_UNAUTHORIZED,
  229. detail=ERROR_MESSAGES.UNAUTHORIZED,
  230. )
  231. result = Tools.delete_tool_by_id(id)
  232. if result:
  233. TOOLS = request.app.state.TOOLS
  234. if id in TOOLS:
  235. del TOOLS[id]
  236. return result
  237. ############################
  238. # GetToolValves
  239. ############################
  240. @router.get("/id/{id}/valves", response_model=Optional[dict])
  241. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  242. tools = Tools.get_tool_by_id(id)
  243. if tools:
  244. try:
  245. valves = Tools.get_tool_valves_by_id(id)
  246. return valves
  247. except Exception as e:
  248. raise HTTPException(
  249. status_code=status.HTTP_400_BAD_REQUEST,
  250. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  251. )
  252. else:
  253. raise HTTPException(
  254. status_code=status.HTTP_401_UNAUTHORIZED,
  255. detail=ERROR_MESSAGES.NOT_FOUND,
  256. )
  257. ############################
  258. # GetToolValvesSpec
  259. ############################
  260. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  261. async def get_tools_valves_spec_by_id(
  262. request: Request, id: str, user=Depends(get_verified_user)
  263. ):
  264. tools = Tools.get_tool_by_id(id)
  265. if tools:
  266. if id in request.app.state.TOOLS:
  267. tools_module = request.app.state.TOOLS[id]
  268. else:
  269. tools_module, _ = load_tool_module_by_id(id)
  270. request.app.state.TOOLS[id] = tools_module
  271. if hasattr(tools_module, "Valves"):
  272. Valves = tools_module.Valves
  273. return Valves.schema()
  274. return None
  275. else:
  276. raise HTTPException(
  277. status_code=status.HTTP_401_UNAUTHORIZED,
  278. detail=ERROR_MESSAGES.NOT_FOUND,
  279. )
  280. ############################
  281. # UpdateToolValves
  282. ############################
  283. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  284. async def update_tools_valves_by_id(
  285. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  286. ):
  287. tools = Tools.get_tool_by_id(id)
  288. if not tools:
  289. raise HTTPException(
  290. status_code=status.HTTP_401_UNAUTHORIZED,
  291. detail=ERROR_MESSAGES.NOT_FOUND,
  292. )
  293. if (
  294. tools.user_id != user.id
  295. and not has_access(user.id, "write", tools.access_control)
  296. and user.role != "admin"
  297. ):
  298. raise HTTPException(
  299. status_code=status.HTTP_400_BAD_REQUEST,
  300. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  301. )
  302. if id in request.app.state.TOOLS:
  303. tools_module = request.app.state.TOOLS[id]
  304. else:
  305. tools_module, _ = load_tool_module_by_id(id)
  306. request.app.state.TOOLS[id] = tools_module
  307. if not hasattr(tools_module, "Valves"):
  308. raise HTTPException(
  309. status_code=status.HTTP_401_UNAUTHORIZED,
  310. detail=ERROR_MESSAGES.NOT_FOUND,
  311. )
  312. Valves = tools_module.Valves
  313. try:
  314. form_data = {k: v for k, v in form_data.items() if v is not None}
  315. valves = Valves(**form_data)
  316. Tools.update_tool_valves_by_id(id, valves.model_dump())
  317. return valves.model_dump()
  318. except Exception as e:
  319. log.exception(f"Failed to update tool valves by id {id}: {e}")
  320. raise HTTPException(
  321. status_code=status.HTTP_400_BAD_REQUEST,
  322. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  323. )
  324. ############################
  325. # ToolUserValves
  326. ############################
  327. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  328. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  329. tools = Tools.get_tool_by_id(id)
  330. if tools:
  331. try:
  332. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  333. return user_valves
  334. except Exception as e:
  335. raise HTTPException(
  336. status_code=status.HTTP_400_BAD_REQUEST,
  337. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  338. )
  339. else:
  340. raise HTTPException(
  341. status_code=status.HTTP_401_UNAUTHORIZED,
  342. detail=ERROR_MESSAGES.NOT_FOUND,
  343. )
  344. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  345. async def get_tools_user_valves_spec_by_id(
  346. request: Request, id: str, user=Depends(get_verified_user)
  347. ):
  348. tools = Tools.get_tool_by_id(id)
  349. if tools:
  350. if id in request.app.state.TOOLS:
  351. tools_module = request.app.state.TOOLS[id]
  352. else:
  353. tools_module, _ = load_tool_module_by_id(id)
  354. request.app.state.TOOLS[id] = tools_module
  355. if hasattr(tools_module, "UserValves"):
  356. UserValves = tools_module.UserValves
  357. return UserValves.schema()
  358. return None
  359. else:
  360. raise HTTPException(
  361. status_code=status.HTTP_401_UNAUTHORIZED,
  362. detail=ERROR_MESSAGES.NOT_FOUND,
  363. )
  364. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  365. async def update_tools_user_valves_by_id(
  366. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  367. ):
  368. tools = Tools.get_tool_by_id(id)
  369. if tools:
  370. if id in request.app.state.TOOLS:
  371. tools_module = request.app.state.TOOLS[id]
  372. else:
  373. tools_module, _ = load_tool_module_by_id(id)
  374. request.app.state.TOOLS[id] = tools_module
  375. if hasattr(tools_module, "UserValves"):
  376. UserValves = tools_module.UserValves
  377. try:
  378. form_data = {k: v for k, v in form_data.items() if v is not None}
  379. user_valves = UserValves(**form_data)
  380. Tools.update_user_valves_by_id_and_user_id(
  381. id, user.id, user_valves.model_dump()
  382. )
  383. return user_valves.model_dump()
  384. except Exception as e:
  385. log.exception(f"Failed to update user valves by id {id}: {e}")
  386. raise HTTPException(
  387. status_code=status.HTTP_400_BAD_REQUEST,
  388. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  389. )
  390. else:
  391. raise HTTPException(
  392. status_code=status.HTTP_401_UNAUTHORIZED,
  393. detail=ERROR_MESSAGES.NOT_FOUND,
  394. )
  395. else:
  396. raise HTTPException(
  397. status_code=status.HTTP_401_UNAUTHORIZED,
  398. detail=ERROR_MESSAGES.NOT_FOUND,
  399. )