pipelines.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. from fastapi import (
  2. Depends,
  3. FastAPI,
  4. File,
  5. Form,
  6. HTTPException,
  7. Request,
  8. UploadFile,
  9. status,
  10. APIRouter,
  11. )
  12. import aiohttp
  13. import os
  14. import logging
  15. import shutil
  16. import requests
  17. from pydantic import BaseModel
  18. from starlette.responses import FileResponse
  19. from typing import Optional
  20. from open_webui.env import SRC_LOG_LEVELS
  21. from open_webui.config import CACHE_DIR
  22. from open_webui.constants import ERROR_MESSAGES
  23. from open_webui.routers.openai import get_all_models_responses
  24. from open_webui.utils.auth import get_admin_user
  25. log = logging.getLogger(__name__)
  26. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  27. ##################################
  28. #
  29. # Pipeline Middleware
  30. #
  31. ##################################
  32. def get_sorted_filters(model_id, models):
  33. filters = [
  34. model
  35. for model in models.values()
  36. if "pipeline" in model
  37. and "type" in model["pipeline"]
  38. and model["pipeline"]["type"] == "filter"
  39. and (
  40. model["pipeline"]["pipelines"] == ["*"]
  41. or any(
  42. model_id == target_model_id
  43. for target_model_id in model["pipeline"]["pipelines"]
  44. )
  45. )
  46. ]
  47. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  48. return sorted_filters
  49. async def process_pipeline_inlet_filter(request, payload, user, models):
  50. user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
  51. model_id = payload["model"]
  52. sorted_filters = get_sorted_filters(model_id, models)
  53. model = models[model_id]
  54. if "pipeline" in model:
  55. sorted_filters.append(model)
  56. async with aiohttp.ClientSession(trust_env=True) as session:
  57. for filter in sorted_filters:
  58. urlIdx = filter.get("urlIdx")
  59. if urlIdx is None:
  60. continue
  61. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  62. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  63. if not key:
  64. continue
  65. headers = {"Authorization": f"Bearer {key}"}
  66. request_data = {
  67. "user": user,
  68. "body": payload,
  69. }
  70. try:
  71. async with session.post(
  72. f"{url}/{filter['id']}/filter/inlet",
  73. headers=headers,
  74. json=request_data,
  75. ) as response:
  76. payload = await response.json()
  77. response.raise_for_status()
  78. except aiohttp.ClientResponseError as e:
  79. res = (
  80. await response.json()
  81. if response.content_type == "application/json"
  82. else {}
  83. )
  84. if "detail" in res:
  85. raise Exception(response.status, res["detail"])
  86. except Exception as e:
  87. log.exception(f"Connection error: {e}")
  88. return payload
  89. async def process_pipeline_outlet_filter(request, payload, user, models):
  90. user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
  91. model_id = payload["model"]
  92. sorted_filters = get_sorted_filters(model_id, models)
  93. model = models[model_id]
  94. if "pipeline" in model:
  95. sorted_filters = [model] + sorted_filters
  96. async with aiohttp.ClientSession(trust_env=True) as session:
  97. for filter in sorted_filters:
  98. urlIdx = filter.get("urlIdx")
  99. if urlIdx is None:
  100. continue
  101. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  102. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  103. if not key:
  104. continue
  105. headers = {"Authorization": f"Bearer {key}"}
  106. request_data = {
  107. "user": user,
  108. "body": payload,
  109. }
  110. try:
  111. async with session.post(
  112. f"{url}/{filter['id']}/filter/outlet",
  113. headers=headers,
  114. json=request_data,
  115. ) as response:
  116. payload = await response.json()
  117. response.raise_for_status()
  118. except aiohttp.ClientResponseError as e:
  119. try:
  120. res = (
  121. await response.json()
  122. if "application/json" in response.content_type
  123. else {}
  124. )
  125. if "detail" in res:
  126. raise Exception(response.status, res)
  127. except Exception:
  128. pass
  129. except Exception as e:
  130. log.exception(f"Connection error: {e}")
  131. return payload
  132. ##################################
  133. #
  134. # Pipelines Endpoints
  135. #
  136. ##################################
  137. router = APIRouter()
  138. @router.get("/list")
  139. async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
  140. responses = await get_all_models_responses(request, user)
  141. log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
  142. urlIdxs = [
  143. idx
  144. for idx, response in enumerate(responses)
  145. if response is not None and "pipelines" in response
  146. ]
  147. return {
  148. "data": [
  149. {
  150. "url": request.app.state.config.OPENAI_API_BASE_URLS[urlIdx],
  151. "idx": urlIdx,
  152. }
  153. for urlIdx in urlIdxs
  154. ]
  155. }
  156. @router.post("/upload")
  157. async def upload_pipeline(
  158. request: Request,
  159. urlIdx: int = Form(...),
  160. file: UploadFile = File(...),
  161. user=Depends(get_admin_user),
  162. ):
  163. log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
  164. filename = os.path.basename(file.filename)
  165. # Check if the uploaded file is a python file
  166. if not (filename and filename.endswith(".py")):
  167. raise HTTPException(
  168. status_code=status.HTTP_400_BAD_REQUEST,
  169. detail="Only Python (.py) files are allowed.",
  170. )
  171. upload_folder = f"{CACHE_DIR}/pipelines"
  172. os.makedirs(upload_folder, exist_ok=True)
  173. file_path = os.path.join(upload_folder, filename)
  174. r = None
  175. try:
  176. # Save the uploaded file
  177. with open(file_path, "wb") as buffer:
  178. shutil.copyfileobj(file.file, buffer)
  179. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  180. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  181. with open(file_path, "rb") as f:
  182. files = {"file": f}
  183. r = requests.post(
  184. f"{url}/pipelines/upload",
  185. headers={"Authorization": f"Bearer {key}"},
  186. files=files,
  187. )
  188. r.raise_for_status()
  189. data = r.json()
  190. return {**data}
  191. except Exception as e:
  192. # Handle connection error here
  193. log.exception(f"Connection error: {e}")
  194. detail = None
  195. status_code = status.HTTP_404_NOT_FOUND
  196. if r is not None:
  197. status_code = r.status_code
  198. try:
  199. res = r.json()
  200. if "detail" in res:
  201. detail = res["detail"]
  202. except Exception:
  203. pass
  204. raise HTTPException(
  205. status_code=status_code,
  206. detail=detail if detail else "Pipeline not found",
  207. )
  208. finally:
  209. # Ensure the file is deleted after the upload is completed or on failure
  210. if os.path.exists(file_path):
  211. os.remove(file_path)
  212. class AddPipelineForm(BaseModel):
  213. url: str
  214. urlIdx: int
  215. @router.post("/add")
  216. async def add_pipeline(
  217. request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user)
  218. ):
  219. r = None
  220. try:
  221. urlIdx = form_data.urlIdx
  222. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  223. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  224. r = requests.post(
  225. f"{url}/pipelines/add",
  226. headers={"Authorization": f"Bearer {key}"},
  227. json={"url": form_data.url},
  228. )
  229. r.raise_for_status()
  230. data = r.json()
  231. return {**data}
  232. except Exception as e:
  233. # Handle connection error here
  234. log.exception(f"Connection error: {e}")
  235. detail = None
  236. if r is not None:
  237. try:
  238. res = r.json()
  239. if "detail" in res:
  240. detail = res["detail"]
  241. except Exception:
  242. pass
  243. raise HTTPException(
  244. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  245. detail=detail if detail else "Pipeline not found",
  246. )
  247. class DeletePipelineForm(BaseModel):
  248. id: str
  249. urlIdx: int
  250. @router.delete("/delete")
  251. async def delete_pipeline(
  252. request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user)
  253. ):
  254. r = None
  255. try:
  256. urlIdx = form_data.urlIdx
  257. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  258. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  259. r = requests.delete(
  260. f"{url}/pipelines/delete",
  261. headers={"Authorization": f"Bearer {key}"},
  262. json={"id": form_data.id},
  263. )
  264. r.raise_for_status()
  265. data = r.json()
  266. return {**data}
  267. except Exception as e:
  268. # Handle connection error here
  269. log.exception(f"Connection error: {e}")
  270. detail = None
  271. if r is not None:
  272. try:
  273. res = r.json()
  274. if "detail" in res:
  275. detail = res["detail"]
  276. except Exception:
  277. pass
  278. raise HTTPException(
  279. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  280. detail=detail if detail else "Pipeline not found",
  281. )
  282. @router.get("/")
  283. async def get_pipelines(
  284. request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user)
  285. ):
  286. r = None
  287. try:
  288. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  289. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  290. r = requests.get(f"{url}/pipelines", headers={"Authorization": f"Bearer {key}"})
  291. r.raise_for_status()
  292. data = r.json()
  293. return {**data}
  294. except Exception as e:
  295. # Handle connection error here
  296. log.exception(f"Connection error: {e}")
  297. detail = None
  298. if r is not None:
  299. try:
  300. res = r.json()
  301. if "detail" in res:
  302. detail = res["detail"]
  303. except Exception:
  304. pass
  305. raise HTTPException(
  306. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  307. detail=detail if detail else "Pipeline not found",
  308. )
  309. @router.get("/{pipeline_id}/valves")
  310. async def get_pipeline_valves(
  311. request: Request,
  312. urlIdx: Optional[int],
  313. pipeline_id: str,
  314. user=Depends(get_admin_user),
  315. ):
  316. r = None
  317. try:
  318. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  319. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  320. r = requests.get(
  321. f"{url}/{pipeline_id}/valves", headers={"Authorization": f"Bearer {key}"}
  322. )
  323. r.raise_for_status()
  324. data = r.json()
  325. return {**data}
  326. except Exception as e:
  327. # Handle connection error here
  328. log.exception(f"Connection error: {e}")
  329. detail = None
  330. if r is not None:
  331. try:
  332. res = r.json()
  333. if "detail" in res:
  334. detail = res["detail"]
  335. except Exception:
  336. pass
  337. raise HTTPException(
  338. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  339. detail=detail if detail else "Pipeline not found",
  340. )
  341. @router.get("/{pipeline_id}/valves/spec")
  342. async def get_pipeline_valves_spec(
  343. request: Request,
  344. urlIdx: Optional[int],
  345. pipeline_id: str,
  346. user=Depends(get_admin_user),
  347. ):
  348. r = None
  349. try:
  350. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  351. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  352. r = requests.get(
  353. f"{url}/{pipeline_id}/valves/spec",
  354. headers={"Authorization": f"Bearer {key}"},
  355. )
  356. r.raise_for_status()
  357. data = r.json()
  358. return {**data}
  359. except Exception as e:
  360. # Handle connection error here
  361. log.exception(f"Connection error: {e}")
  362. detail = None
  363. if r is not None:
  364. try:
  365. res = r.json()
  366. if "detail" in res:
  367. detail = res["detail"]
  368. except Exception:
  369. pass
  370. raise HTTPException(
  371. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  372. detail=detail if detail else "Pipeline not found",
  373. )
  374. @router.post("/{pipeline_id}/valves/update")
  375. async def update_pipeline_valves(
  376. request: Request,
  377. urlIdx: Optional[int],
  378. pipeline_id: str,
  379. form_data: dict,
  380. user=Depends(get_admin_user),
  381. ):
  382. r = None
  383. try:
  384. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  385. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  386. r = requests.post(
  387. f"{url}/{pipeline_id}/valves/update",
  388. headers={"Authorization": f"Bearer {key}"},
  389. json={**form_data},
  390. )
  391. r.raise_for_status()
  392. data = r.json()
  393. return {**data}
  394. except Exception as e:
  395. # Handle connection error here
  396. log.exception(f"Connection error: {e}")
  397. detail = None
  398. if r is not None:
  399. try:
  400. res = r.json()
  401. if "detail" in res:
  402. detail = res["detail"]
  403. except Exception:
  404. pass
  405. raise HTTPException(
  406. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  407. detail=detail if detail else "Pipeline not found",
  408. )