pipelines.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. from fastapi import (
  2. Depends,
  3. FastAPI,
  4. File,
  5. Form,
  6. HTTPException,
  7. Request,
  8. UploadFile,
  9. status,
  10. APIRouter,
  11. )
  12. import aiohttp
  13. import os
  14. import logging
  15. import shutil
  16. import requests
  17. from pydantic import BaseModel
  18. from starlette.responses import FileResponse
  19. from typing import Optional
  20. from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
  21. from open_webui.config import CACHE_DIR
  22. from open_webui.constants import ERROR_MESSAGES
  23. from open_webui.routers.openai import get_all_models_responses
  24. from open_webui.utils.auth import get_admin_user
# Module logger; its level is taken from SRC_LOG_LEVELS["MAIN"].
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
  27. ##################################
  28. #
  29. # Pipeline Middleware
  30. #
  31. ##################################
  32. def get_sorted_filters(model_id, models):
  33. filters = [
  34. model
  35. for model in models.values()
  36. if "pipeline" in model
  37. and "type" in model["pipeline"]
  38. and model["pipeline"]["type"] == "filter"
  39. and (
  40. model["pipeline"]["pipelines"] == ["*"]
  41. or any(
  42. model_id == target_model_id
  43. for target_model_id in model["pipeline"]["pipelines"]
  44. )
  45. )
  46. ]
  47. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  48. return sorted_filters
  49. async def process_pipeline_inlet_filter(request, payload, user, models):
  50. user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
  51. model_id = payload["model"]
  52. sorted_filters = get_sorted_filters(model_id, models)
  53. model = models[model_id]
  54. if "pipeline" in model:
  55. sorted_filters.append(model)
  56. async with aiohttp.ClientSession(trust_env=True) as session:
  57. for filter in sorted_filters:
  58. urlIdx = filter.get("urlIdx")
  59. if urlIdx is None:
  60. continue
  61. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  62. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  63. if not key:
  64. continue
  65. headers = {"Authorization": f"Bearer {key}"}
  66. request_data = {
  67. "user": user,
  68. "body": payload,
  69. }
  70. try:
  71. async with session.post(
  72. f"{url}/{filter['id']}/filter/inlet",
  73. headers=headers,
  74. json=request_data,
  75. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  76. ) as response:
  77. payload = await response.json()
  78. response.raise_for_status()
  79. except aiohttp.ClientResponseError as e:
  80. res = (
  81. await response.json()
  82. if response.content_type == "application/json"
  83. else {}
  84. )
  85. if "detail" in res:
  86. raise Exception(response.status, res["detail"])
  87. except Exception as e:
  88. log.exception(f"Connection error: {e}")
  89. return payload
  90. async def process_pipeline_outlet_filter(request, payload, user, models):
  91. user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
  92. model_id = payload["model"]
  93. sorted_filters = get_sorted_filters(model_id, models)
  94. model = models[model_id]
  95. if "pipeline" in model:
  96. sorted_filters = [model] + sorted_filters
  97. async with aiohttp.ClientSession(trust_env=True) as session:
  98. for filter in sorted_filters:
  99. urlIdx = filter.get("urlIdx")
  100. if urlIdx is None:
  101. continue
  102. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  103. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  104. if not key:
  105. continue
  106. headers = {"Authorization": f"Bearer {key}"}
  107. request_data = {
  108. "user": user,
  109. "body": payload,
  110. }
  111. try:
  112. async with session.post(
  113. f"{url}/{filter['id']}/filter/outlet",
  114. headers=headers,
  115. json=request_data,
  116. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  117. ) as response:
  118. payload = await response.json()
  119. response.raise_for_status()
  120. except aiohttp.ClientResponseError as e:
  121. try:
  122. res = (
  123. await response.json()
  124. if "application/json" in response.content_type
  125. else {}
  126. )
  127. if "detail" in res:
  128. raise Exception(response.status, res)
  129. except Exception:
  130. pass
  131. except Exception as e:
  132. log.exception(f"Connection error: {e}")
  133. return payload
  134. ##################################
  135. #
  136. # Pipelines Endpoints
  137. #
  138. ##################################
# Router for the pipelines admin endpoints below; every endpoint depends on
# get_admin_user.
router = APIRouter()
  140. @router.get("/list")
  141. async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
  142. responses = await get_all_models_responses(request, user)
  143. log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
  144. urlIdxs = [
  145. idx
  146. for idx, response in enumerate(responses)
  147. if response is not None and "pipelines" in response
  148. ]
  149. return {
  150. "data": [
  151. {
  152. "url": request.app.state.config.OPENAI_API_BASE_URLS[urlIdx],
  153. "idx": urlIdx,
  154. }
  155. for urlIdx in urlIdxs
  156. ]
  157. }
  158. @router.post("/upload")
  159. async def upload_pipeline(
  160. request: Request,
  161. urlIdx: int = Form(...),
  162. file: UploadFile = File(...),
  163. user=Depends(get_admin_user),
  164. ):
  165. log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
  166. filename = os.path.basename(file.filename)
  167. # Check if the uploaded file is a python file
  168. if not (filename and filename.endswith(".py")):
  169. raise HTTPException(
  170. status_code=status.HTTP_400_BAD_REQUEST,
  171. detail="Only Python (.py) files are allowed.",
  172. )
  173. upload_folder = f"{CACHE_DIR}/pipelines"
  174. os.makedirs(upload_folder, exist_ok=True)
  175. file_path = os.path.join(upload_folder, filename)
  176. r = None
  177. try:
  178. # Save the uploaded file
  179. with open(file_path, "wb") as buffer:
  180. shutil.copyfileobj(file.file, buffer)
  181. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  182. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  183. with open(file_path, "rb") as f:
  184. files = {"file": f}
  185. r = requests.post(
  186. f"{url}/pipelines/upload",
  187. headers={"Authorization": f"Bearer {key}"},
  188. files=files,
  189. )
  190. r.raise_for_status()
  191. data = r.json()
  192. return {**data}
  193. except Exception as e:
  194. # Handle connection error here
  195. log.exception(f"Connection error: {e}")
  196. detail = None
  197. status_code = status.HTTP_404_NOT_FOUND
  198. if r is not None:
  199. status_code = r.status_code
  200. try:
  201. res = r.json()
  202. if "detail" in res:
  203. detail = res["detail"]
  204. except Exception:
  205. pass
  206. raise HTTPException(
  207. status_code=status_code,
  208. detail=detail if detail else "Pipeline not found",
  209. )
  210. finally:
  211. # Ensure the file is deleted after the upload is completed or on failure
  212. if os.path.exists(file_path):
  213. os.remove(file_path)
class AddPipelineForm(BaseModel):
    # Request body for POST /add.
    # url: sent unchanged as {"url": url} to the pipelines server's /pipelines/add.
    url: str
    # urlIdx: index into OPENAI_API_BASE_URLS / OPENAI_API_KEYS that selects
    # which pipelines server receives the request.
    urlIdx: int
  217. @router.post("/add")
  218. async def add_pipeline(
  219. request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user)
  220. ):
  221. r = None
  222. try:
  223. urlIdx = form_data.urlIdx
  224. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  225. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  226. r = requests.post(
  227. f"{url}/pipelines/add",
  228. headers={"Authorization": f"Bearer {key}"},
  229. json={"url": form_data.url},
  230. )
  231. r.raise_for_status()
  232. data = r.json()
  233. return {**data}
  234. except Exception as e:
  235. # Handle connection error here
  236. log.exception(f"Connection error: {e}")
  237. detail = None
  238. if r is not None:
  239. try:
  240. res = r.json()
  241. if "detail" in res:
  242. detail = res["detail"]
  243. except Exception:
  244. pass
  245. raise HTTPException(
  246. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  247. detail=detail if detail else "Pipeline not found",
  248. )
class DeletePipelineForm(BaseModel):
    # Request body for DELETE /delete.
    # id: sent unchanged as {"id": id} to the pipelines server's /pipelines/delete.
    id: str
    # urlIdx: index into OPENAI_API_BASE_URLS / OPENAI_API_KEYS that selects
    # which pipelines server receives the request.
    urlIdx: int
  252. @router.delete("/delete")
  253. async def delete_pipeline(
  254. request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user)
  255. ):
  256. r = None
  257. try:
  258. urlIdx = form_data.urlIdx
  259. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  260. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  261. r = requests.delete(
  262. f"{url}/pipelines/delete",
  263. headers={"Authorization": f"Bearer {key}"},
  264. json={"id": form_data.id},
  265. )
  266. r.raise_for_status()
  267. data = r.json()
  268. return {**data}
  269. except Exception as e:
  270. # Handle connection error here
  271. log.exception(f"Connection error: {e}")
  272. detail = None
  273. if r is not None:
  274. try:
  275. res = r.json()
  276. if "detail" in res:
  277. detail = res["detail"]
  278. except Exception:
  279. pass
  280. raise HTTPException(
  281. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  282. detail=detail if detail else "Pipeline not found",
  283. )
  284. @router.get("/")
  285. async def get_pipelines(
  286. request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user)
  287. ):
  288. r = None
  289. try:
  290. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  291. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  292. r = requests.get(f"{url}/pipelines", headers={"Authorization": f"Bearer {key}"})
  293. r.raise_for_status()
  294. data = r.json()
  295. return {**data}
  296. except Exception as e:
  297. # Handle connection error here
  298. log.exception(f"Connection error: {e}")
  299. detail = None
  300. if r is not None:
  301. try:
  302. res = r.json()
  303. if "detail" in res:
  304. detail = res["detail"]
  305. except Exception:
  306. pass
  307. raise HTTPException(
  308. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  309. detail=detail if detail else "Pipeline not found",
  310. )
  311. @router.get("/{pipeline_id}/valves")
  312. async def get_pipeline_valves(
  313. request: Request,
  314. urlIdx: Optional[int],
  315. pipeline_id: str,
  316. user=Depends(get_admin_user),
  317. ):
  318. r = None
  319. try:
  320. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  321. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  322. r = requests.get(
  323. f"{url}/{pipeline_id}/valves", headers={"Authorization": f"Bearer {key}"}
  324. )
  325. r.raise_for_status()
  326. data = r.json()
  327. return {**data}
  328. except Exception as e:
  329. # Handle connection error here
  330. log.exception(f"Connection error: {e}")
  331. detail = None
  332. if r is not None:
  333. try:
  334. res = r.json()
  335. if "detail" in res:
  336. detail = res["detail"]
  337. except Exception:
  338. pass
  339. raise HTTPException(
  340. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  341. detail=detail if detail else "Pipeline not found",
  342. )
  343. @router.get("/{pipeline_id}/valves/spec")
  344. async def get_pipeline_valves_spec(
  345. request: Request,
  346. urlIdx: Optional[int],
  347. pipeline_id: str,
  348. user=Depends(get_admin_user),
  349. ):
  350. r = None
  351. try:
  352. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  353. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  354. r = requests.get(
  355. f"{url}/{pipeline_id}/valves/spec",
  356. headers={"Authorization": f"Bearer {key}"},
  357. )
  358. r.raise_for_status()
  359. data = r.json()
  360. return {**data}
  361. except Exception as e:
  362. # Handle connection error here
  363. log.exception(f"Connection error: {e}")
  364. detail = None
  365. if r is not None:
  366. try:
  367. res = r.json()
  368. if "detail" in res:
  369. detail = res["detail"]
  370. except Exception:
  371. pass
  372. raise HTTPException(
  373. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  374. detail=detail if detail else "Pipeline not found",
  375. )
  376. @router.post("/{pipeline_id}/valves/update")
  377. async def update_pipeline_valves(
  378. request: Request,
  379. urlIdx: Optional[int],
  380. pipeline_id: str,
  381. form_data: dict,
  382. user=Depends(get_admin_user),
  383. ):
  384. r = None
  385. try:
  386. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  387. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  388. r = requests.post(
  389. f"{url}/{pipeline_id}/valves/update",
  390. headers={"Authorization": f"Bearer {key}"},
  391. json={**form_data},
  392. )
  393. r.raise_for_status()
  394. data = r.json()
  395. return {**data}
  396. except Exception as e:
  397. # Handle connection error here
  398. log.exception(f"Connection error: {e}")
  399. detail = None
  400. if r is not None:
  401. try:
  402. res = r.json()
  403. if "detail" in res:
  404. detail = res["detail"]
  405. except Exception:
  406. pass
  407. raise HTTPException(
  408. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  409. detail=detail if detail else "Pipeline not found",
  410. )