functions.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. import os
  2. import re
  3. import logging
  4. import aiohttp
  5. from pathlib import Path
  6. from typing import Optional
  7. from open_webui.models.functions import (
  8. FunctionForm,
  9. FunctionModel,
  10. FunctionResponse,
  11. FunctionUserResponse,
  12. FunctionWithValvesModel,
  13. Functions,
  14. )
  15. from open_webui.utils.plugin import (
  16. load_function_module_by_id,
  17. replace_imports,
  18. get_function_module_from_cache,
  19. )
  20. from open_webui.config import CACHE_DIR
  21. from open_webui.constants import ERROR_MESSAGES
  22. from fastapi import APIRouter, Depends, HTTPException, Request, status
  23. from open_webui.utils.auth import get_admin_user, get_verified_user
  24. from open_webui.env import SRC_LOG_LEVELS
  25. from pydantic import BaseModel, HttpUrl
  26. log = logging.getLogger(__name__)
  27. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  28. router = APIRouter()
  29. ############################
  30. # GetFunctions
  31. ############################
  32. @router.get("/", response_model=list[FunctionResponse])
  33. async def get_functions(user=Depends(get_verified_user)):
  34. return Functions.get_functions()
  35. @router.get("/list", response_model=list[FunctionUserResponse])
  36. async def get_function_list(user=Depends(get_admin_user)):
  37. return Functions.get_function_list()
  38. ############################
  39. # ExportFunctions
  40. ############################
  41. @router.get("/export", response_model=list[FunctionModel | FunctionWithValvesModel])
  42. async def get_functions(include_valves: bool = False, user=Depends(get_admin_user)):
  43. return Functions.get_functions(include_valves=include_valves)
  44. ############################
  45. # LoadFunctionFromLink
  46. ############################
class LoadUrlForm(BaseModel):
    """Request body accepted by the "/load/url" endpoint."""

    # Location to fetch the function source from; validated as an HttpUrl.
    url: HttpUrl
  49. def github_url_to_raw_url(url: str) -> str:
  50. # Handle 'tree' (folder) URLs (add main.py at the end)
  51. m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
  52. if m1:
  53. org, repo, branch, path = m1.groups()
  54. return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
  55. # Handle 'blob' (file) URLs
  56. m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
  57. if m2:
  58. org, repo, branch, path = m2.groups()
  59. return (
  60. f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
  61. )
  62. # No match; return as-is
  63. return url
  64. @router.post("/load/url", response_model=Optional[dict])
  65. async def load_function_from_url(
  66. request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
  67. ):
  68. # NOTE: This is NOT a SSRF vulnerability:
  69. # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
  70. # and does NOT accept untrusted user input. Access is enforced by authentication.
  71. url = str(form_data.url)
  72. if not url:
  73. raise HTTPException(status_code=400, detail="Please enter a valid URL")
  74. url = github_url_to_raw_url(url)
  75. url_parts = url.rstrip("/").split("/")
  76. file_name = url_parts[-1]
  77. function_name = (
  78. file_name[:-3]
  79. if (
  80. file_name.endswith(".py")
  81. and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
  82. )
  83. else url_parts[-2] if len(url_parts) > 1 else "function"
  84. )
  85. try:
  86. async with aiohttp.ClientSession(trust_env=True) as session:
  87. async with session.get(
  88. url, headers={"Content-Type": "application/json"}
  89. ) as resp:
  90. if resp.status != 200:
  91. raise HTTPException(
  92. status_code=resp.status, detail="Failed to fetch the function"
  93. )
  94. data = await resp.text()
  95. if not data:
  96. raise HTTPException(
  97. status_code=400, detail="No data received from the URL"
  98. )
  99. return {
  100. "name": function_name,
  101. "content": data,
  102. }
  103. except Exception as e:
  104. raise HTTPException(status_code=500, detail=f"Error importing function: {e}")
  105. ############################
  106. # SyncFunctions
  107. ############################
class SyncFunctionsForm(BaseModel):
    """Request body accepted by the "/sync" endpoint."""

    # Functions (with valves) to synchronise; defaults to an empty list.
    functions: list[FunctionWithValvesModel] = []
  110. @router.post("/sync", response_model=list[FunctionWithValvesModel])
  111. async def sync_functions(
  112. request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
  113. ):
  114. try:
  115. for function in form_data.functions:
  116. function.content = replace_imports(function.content)
  117. function_module, function_type, frontmatter = load_function_module_by_id(
  118. function.id,
  119. content=function.content,
  120. )
  121. if hasattr(function_module, "Valves") and function.valves:
  122. Valves = function_module.Valves
  123. try:
  124. Valves(
  125. **{k: v for k, v in function.valves.items() if v is not None}
  126. )
  127. except Exception as e:
  128. log.exception(
  129. f"Error validating valves for function {function.id}: {e}"
  130. )
  131. raise e
  132. return Functions.sync_functions(user.id, form_data.functions)
  133. except Exception as e:
  134. log.exception(f"Failed to load a function: {e}")
  135. raise HTTPException(
  136. status_code=status.HTTP_400_BAD_REQUEST,
  137. detail=ERROR_MESSAGES.DEFAULT(e),
  138. )
  139. ############################
  140. # CreateNewFunction
  141. ############################
  142. @router.post("/create", response_model=Optional[FunctionResponse])
  143. async def create_new_function(
  144. request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
  145. ):
  146. if not form_data.id.isidentifier():
  147. raise HTTPException(
  148. status_code=status.HTTP_400_BAD_REQUEST,
  149. detail="Only alphanumeric characters and underscores are allowed in the id",
  150. )
  151. form_data.id = form_data.id.lower()
  152. function = Functions.get_function_by_id(form_data.id)
  153. if function is None:
  154. try:
  155. form_data.content = replace_imports(form_data.content)
  156. function_module, function_type, frontmatter = load_function_module_by_id(
  157. form_data.id,
  158. content=form_data.content,
  159. )
  160. form_data.meta.manifest = frontmatter
  161. FUNCTIONS = request.app.state.FUNCTIONS
  162. FUNCTIONS[form_data.id] = function_module
  163. function = Functions.insert_new_function(user.id, function_type, form_data)
  164. function_cache_dir = CACHE_DIR / "functions" / form_data.id
  165. function_cache_dir.mkdir(parents=True, exist_ok=True)
  166. if function_type == "filter" and getattr(function_module, "toggle", None):
  167. Functions.update_function_metadata_by_id(id, {"toggle": True})
  168. if function:
  169. return function
  170. else:
  171. raise HTTPException(
  172. status_code=status.HTTP_400_BAD_REQUEST,
  173. detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
  174. )
  175. except Exception as e:
  176. log.exception(f"Failed to create a new function: {e}")
  177. raise HTTPException(
  178. status_code=status.HTTP_400_BAD_REQUEST,
  179. detail=ERROR_MESSAGES.DEFAULT(e),
  180. )
  181. else:
  182. raise HTTPException(
  183. status_code=status.HTTP_400_BAD_REQUEST,
  184. detail=ERROR_MESSAGES.ID_TAKEN,
  185. )
  186. ############################
  187. # GetFunctionById
  188. ############################
  189. @router.get("/id/{id}", response_model=Optional[FunctionModel])
  190. async def get_function_by_id(id: str, user=Depends(get_admin_user)):
  191. function = Functions.get_function_by_id(id)
  192. if function:
  193. return function
  194. else:
  195. raise HTTPException(
  196. status_code=status.HTTP_401_UNAUTHORIZED,
  197. detail=ERROR_MESSAGES.NOT_FOUND,
  198. )
  199. ############################
  200. # ToggleFunctionById
  201. ############################
  202. @router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
  203. async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
  204. function = Functions.get_function_by_id(id)
  205. if function:
  206. function = Functions.update_function_by_id(
  207. id, {"is_active": not function.is_active}
  208. )
  209. if function:
  210. return function
  211. else:
  212. raise HTTPException(
  213. status_code=status.HTTP_400_BAD_REQUEST,
  214. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  215. )
  216. else:
  217. raise HTTPException(
  218. status_code=status.HTTP_401_UNAUTHORIZED,
  219. detail=ERROR_MESSAGES.NOT_FOUND,
  220. )
  221. ############################
  222. # ToggleGlobalById
  223. ############################
  224. @router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
  225. async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
  226. function = Functions.get_function_by_id(id)
  227. if function:
  228. function = Functions.update_function_by_id(
  229. id, {"is_global": not function.is_global}
  230. )
  231. if function:
  232. return function
  233. else:
  234. raise HTTPException(
  235. status_code=status.HTTP_400_BAD_REQUEST,
  236. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  237. )
  238. else:
  239. raise HTTPException(
  240. status_code=status.HTTP_401_UNAUTHORIZED,
  241. detail=ERROR_MESSAGES.NOT_FOUND,
  242. )
  243. ############################
  244. # UpdateFunctionById
  245. ############################
  246. @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
  247. async def update_function_by_id(
  248. request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
  249. ):
  250. try:
  251. form_data.content = replace_imports(form_data.content)
  252. function_module, function_type, frontmatter = load_function_module_by_id(
  253. id, content=form_data.content
  254. )
  255. form_data.meta.manifest = frontmatter
  256. FUNCTIONS = request.app.state.FUNCTIONS
  257. FUNCTIONS[id] = function_module
  258. updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
  259. log.debug(updated)
  260. function = Functions.update_function_by_id(id, updated)
  261. if function_type == "filter" and getattr(function_module, "toggle", None):
  262. Functions.update_function_metadata_by_id(id, {"toggle": True})
  263. if function:
  264. return function
  265. else:
  266. raise HTTPException(
  267. status_code=status.HTTP_400_BAD_REQUEST,
  268. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  269. )
  270. except Exception as e:
  271. raise HTTPException(
  272. status_code=status.HTTP_400_BAD_REQUEST,
  273. detail=ERROR_MESSAGES.DEFAULT(e),
  274. )
  275. ############################
  276. # DeleteFunctionById
  277. ############################
  278. @router.delete("/id/{id}/delete", response_model=bool)
  279. async def delete_function_by_id(
  280. request: Request, id: str, user=Depends(get_admin_user)
  281. ):
  282. result = Functions.delete_function_by_id(id)
  283. if result:
  284. FUNCTIONS = request.app.state.FUNCTIONS
  285. if id in FUNCTIONS:
  286. del FUNCTIONS[id]
  287. return result
  288. ############################
  289. # GetFunctionValves
  290. ############################
  291. @router.get("/id/{id}/valves", response_model=Optional[dict])
  292. async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
  293. function = Functions.get_function_by_id(id)
  294. if function:
  295. try:
  296. valves = Functions.get_function_valves_by_id(id)
  297. return valves
  298. except Exception as e:
  299. raise HTTPException(
  300. status_code=status.HTTP_400_BAD_REQUEST,
  301. detail=ERROR_MESSAGES.DEFAULT(e),
  302. )
  303. else:
  304. raise HTTPException(
  305. status_code=status.HTTP_401_UNAUTHORIZED,
  306. detail=ERROR_MESSAGES.NOT_FOUND,
  307. )
  308. ############################
  309. # GetFunctionValvesSpec
  310. ############################
  311. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  312. async def get_function_valves_spec_by_id(
  313. request: Request, id: str, user=Depends(get_admin_user)
  314. ):
  315. function = Functions.get_function_by_id(id)
  316. if function:
  317. function_module, function_type, frontmatter = get_function_module_from_cache(
  318. request, id
  319. )
  320. if hasattr(function_module, "Valves"):
  321. Valves = function_module.Valves
  322. return Valves.schema()
  323. return None
  324. else:
  325. raise HTTPException(
  326. status_code=status.HTTP_401_UNAUTHORIZED,
  327. detail=ERROR_MESSAGES.NOT_FOUND,
  328. )
  329. ############################
  330. # UpdateFunctionValves
  331. ############################
  332. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  333. async def update_function_valves_by_id(
  334. request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
  335. ):
  336. function = Functions.get_function_by_id(id)
  337. if function:
  338. function_module, function_type, frontmatter = get_function_module_from_cache(
  339. request, id
  340. )
  341. if hasattr(function_module, "Valves"):
  342. Valves = function_module.Valves
  343. try:
  344. form_data = {k: v for k, v in form_data.items() if v is not None}
  345. valves = Valves(**form_data)
  346. valves_dict = valves.model_dump(exclude_unset=True)
  347. Functions.update_function_valves_by_id(id, valves_dict)
  348. return valves_dict
  349. except Exception as e:
  350. log.exception(f"Error updating function values by id {id}: {e}")
  351. raise HTTPException(
  352. status_code=status.HTTP_400_BAD_REQUEST,
  353. detail=ERROR_MESSAGES.DEFAULT(e),
  354. )
  355. else:
  356. raise HTTPException(
  357. status_code=status.HTTP_401_UNAUTHORIZED,
  358. detail=ERROR_MESSAGES.NOT_FOUND,
  359. )
  360. else:
  361. raise HTTPException(
  362. status_code=status.HTTP_401_UNAUTHORIZED,
  363. detail=ERROR_MESSAGES.NOT_FOUND,
  364. )
  365. ############################
  366. # FunctionUserValves
  367. ############################
  368. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  369. async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  370. function = Functions.get_function_by_id(id)
  371. if function:
  372. try:
  373. user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
  374. return user_valves
  375. except Exception as e:
  376. raise HTTPException(
  377. status_code=status.HTTP_400_BAD_REQUEST,
  378. detail=ERROR_MESSAGES.DEFAULT(e),
  379. )
  380. else:
  381. raise HTTPException(
  382. status_code=status.HTTP_401_UNAUTHORIZED,
  383. detail=ERROR_MESSAGES.NOT_FOUND,
  384. )
  385. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  386. async def get_function_user_valves_spec_by_id(
  387. request: Request, id: str, user=Depends(get_verified_user)
  388. ):
  389. function = Functions.get_function_by_id(id)
  390. if function:
  391. function_module, function_type, frontmatter = get_function_module_from_cache(
  392. request, id
  393. )
  394. if hasattr(function_module, "UserValves"):
  395. UserValves = function_module.UserValves
  396. return UserValves.schema()
  397. return None
  398. else:
  399. raise HTTPException(
  400. status_code=status.HTTP_401_UNAUTHORIZED,
  401. detail=ERROR_MESSAGES.NOT_FOUND,
  402. )
  403. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  404. async def update_function_user_valves_by_id(
  405. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  406. ):
  407. function = Functions.get_function_by_id(id)
  408. if function:
  409. function_module, function_type, frontmatter = get_function_module_from_cache(
  410. request, id
  411. )
  412. if hasattr(function_module, "UserValves"):
  413. UserValves = function_module.UserValves
  414. try:
  415. form_data = {k: v for k, v in form_data.items() if v is not None}
  416. user_valves = UserValves(**form_data)
  417. user_valves_dict = user_valves.model_dump(exclude_unset=True)
  418. Functions.update_user_valves_by_id_and_user_id(
  419. id, user.id, user_valves_dict
  420. )
  421. return user_valves_dict
  422. except Exception as e:
  423. log.exception(f"Error updating function user valves by id {id}: {e}")
  424. raise HTTPException(
  425. status_code=status.HTTP_400_BAD_REQUEST,
  426. detail=ERROR_MESSAGES.DEFAULT(e),
  427. )
  428. else:
  429. raise HTTPException(
  430. status_code=status.HTTP_401_UNAUTHORIZED,
  431. detail=ERROR_MESSAGES.NOT_FOUND,
  432. )
  433. else:
  434. raise HTTPException(
  435. status_code=status.HTTP_401_UNAUTHORIZED,
  436. detail=ERROR_MESSAGES.NOT_FOUND,
  437. )