pipelines.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
import logging
import os
import shutil
from typing import Optional

import aiohttp
import requests
from fastapi import (
    APIRouter,
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
    status,
)
from pydantic import BaseModel
from starlette.concurrency import run_in_threadpool
from starlette.responses import FileResponse

from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.routers.openai import get_all_models_responses
from open_webui.utils.auth import get_admin_user
  25. log = logging.getLogger(__name__)
  26. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  27. ##################################
  28. #
  29. # Pipeline Middleware
  30. #
  31. ##################################
  32. def get_sorted_filters(model_id, models):
  33. filters = [
  34. model
  35. for model in models.values()
  36. if "pipeline" in model
  37. and "type" in model["pipeline"]
  38. and model["pipeline"]["type"] == "filter"
  39. and (
  40. model["pipeline"]["pipelines"] == ["*"]
  41. or any(
  42. model_id == target_model_id
  43. for target_model_id in model["pipeline"]["pipelines"]
  44. )
  45. )
  46. ]
  47. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  48. return sorted_filters
  49. async def process_pipeline_inlet_filter(request, payload, user, models):
  50. user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
  51. model_id = payload["model"]
  52. sorted_filters = get_sorted_filters(model_id, models)
  53. model = models[model_id]
  54. if "pipeline" in model:
  55. sorted_filters.append(model)
  56. async with aiohttp.ClientSession(trust_env=True) as session:
  57. for filter in sorted_filters:
  58. urlIdx = filter.get("urlIdx")
  59. try:
  60. urlIdx = int(urlIdx)
  61. except:
  62. continue
  63. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  64. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  65. if not key:
  66. continue
  67. headers = {"Authorization": f"Bearer {key}"}
  68. request_data = {
  69. "user": user,
  70. "body": payload,
  71. }
  72. try:
  73. async with session.post(
  74. f"{url}/{filter['id']}/filter/inlet",
  75. headers=headers,
  76. json=request_data,
  77. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  78. ) as response:
  79. payload = await response.json()
  80. response.raise_for_status()
  81. except aiohttp.ClientResponseError as e:
  82. res = (
  83. await response.json()
  84. if response.content_type == "application/json"
  85. else {}
  86. )
  87. if "detail" in res:
  88. raise Exception(response.status, res["detail"])
  89. except Exception as e:
  90. log.exception(f"Connection error: {e}")
  91. return payload
  92. async def process_pipeline_outlet_filter(request, payload, user, models):
  93. user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
  94. model_id = payload["model"]
  95. sorted_filters = get_sorted_filters(model_id, models)
  96. model = models[model_id]
  97. if "pipeline" in model:
  98. sorted_filters = [model] + sorted_filters
  99. async with aiohttp.ClientSession(trust_env=True) as session:
  100. for filter in sorted_filters:
  101. urlIdx = filter.get("urlIdx")
  102. try:
  103. urlIdx = int(urlIdx)
  104. except:
  105. continue
  106. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  107. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  108. if not key:
  109. continue
  110. headers = {"Authorization": f"Bearer {key}"}
  111. request_data = {
  112. "user": user,
  113. "body": payload,
  114. }
  115. try:
  116. async with session.post(
  117. f"{url}/{filter['id']}/filter/outlet",
  118. headers=headers,
  119. json=request_data,
  120. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  121. ) as response:
  122. payload = await response.json()
  123. response.raise_for_status()
  124. except aiohttp.ClientResponseError as e:
  125. try:
  126. res = (
  127. await response.json()
  128. if "application/json" in response.content_type
  129. else {}
  130. )
  131. if "detail" in res:
  132. raise Exception(response.status, res)
  133. except Exception:
  134. pass
  135. except Exception as e:
  136. log.exception(f"Connection error: {e}")
  137. return payload
  138. ##################################
  139. #
  140. # Pipelines Endpoints
  141. #
  142. ##################################
  143. router = APIRouter()
  144. @router.get("/list")
  145. async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
  146. responses = await get_all_models_responses(request, user)
  147. log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
  148. urlIdxs = [
  149. idx
  150. for idx, response in enumerate(responses)
  151. if response is not None and "pipelines" in response
  152. ]
  153. return {
  154. "data": [
  155. {
  156. "url": request.app.state.config.OPENAI_API_BASE_URLS[urlIdx],
  157. "idx": urlIdx,
  158. }
  159. for urlIdx in urlIdxs
  160. ]
  161. }
  162. @router.post("/upload")
  163. async def upload_pipeline(
  164. request: Request,
  165. urlIdx: int = Form(...),
  166. file: UploadFile = File(...),
  167. user=Depends(get_admin_user),
  168. ):
  169. log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
  170. filename = os.path.basename(file.filename)
  171. # Check if the uploaded file is a python file
  172. if not (filename and filename.endswith(".py")):
  173. raise HTTPException(
  174. status_code=status.HTTP_400_BAD_REQUEST,
  175. detail="Only Python (.py) files are allowed.",
  176. )
  177. upload_folder = f"{CACHE_DIR}/pipelines"
  178. os.makedirs(upload_folder, exist_ok=True)
  179. file_path = os.path.join(upload_folder, filename)
  180. r = None
  181. try:
  182. # Save the uploaded file
  183. with open(file_path, "wb") as buffer:
  184. shutil.copyfileobj(file.file, buffer)
  185. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  186. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  187. with open(file_path, "rb") as f:
  188. files = {"file": f}
  189. r = requests.post(
  190. f"{url}/pipelines/upload",
  191. headers={"Authorization": f"Bearer {key}"},
  192. files=files,
  193. )
  194. r.raise_for_status()
  195. data = r.json()
  196. return {**data}
  197. except Exception as e:
  198. # Handle connection error here
  199. log.exception(f"Connection error: {e}")
  200. detail = None
  201. status_code = status.HTTP_404_NOT_FOUND
  202. if r is not None:
  203. status_code = r.status_code
  204. try:
  205. res = r.json()
  206. if "detail" in res:
  207. detail = res["detail"]
  208. except Exception:
  209. pass
  210. raise HTTPException(
  211. status_code=status_code,
  212. detail=detail if detail else "Pipeline not found",
  213. )
  214. finally:
  215. # Ensure the file is deleted after the upload is completed or on failure
  216. if os.path.exists(file_path):
  217. os.remove(file_path)
  218. class AddPipelineForm(BaseModel):
  219. url: str
  220. urlIdx: int
  221. @router.post("/add")
  222. async def add_pipeline(
  223. request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user)
  224. ):
  225. r = None
  226. try:
  227. urlIdx = form_data.urlIdx
  228. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  229. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  230. r = requests.post(
  231. f"{url}/pipelines/add",
  232. headers={"Authorization": f"Bearer {key}"},
  233. json={"url": form_data.url},
  234. )
  235. r.raise_for_status()
  236. data = r.json()
  237. return {**data}
  238. except Exception as e:
  239. # Handle connection error here
  240. log.exception(f"Connection error: {e}")
  241. detail = None
  242. if r is not None:
  243. try:
  244. res = r.json()
  245. if "detail" in res:
  246. detail = res["detail"]
  247. except Exception:
  248. pass
  249. raise HTTPException(
  250. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  251. detail=detail if detail else "Pipeline not found",
  252. )
  253. class DeletePipelineForm(BaseModel):
  254. id: str
  255. urlIdx: int
  256. @router.delete("/delete")
  257. async def delete_pipeline(
  258. request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user)
  259. ):
  260. r = None
  261. try:
  262. urlIdx = form_data.urlIdx
  263. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  264. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  265. r = requests.delete(
  266. f"{url}/pipelines/delete",
  267. headers={"Authorization": f"Bearer {key}"},
  268. json={"id": form_data.id},
  269. )
  270. r.raise_for_status()
  271. data = r.json()
  272. return {**data}
  273. except Exception as e:
  274. # Handle connection error here
  275. log.exception(f"Connection error: {e}")
  276. detail = None
  277. if r is not None:
  278. try:
  279. res = r.json()
  280. if "detail" in res:
  281. detail = res["detail"]
  282. except Exception:
  283. pass
  284. raise HTTPException(
  285. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  286. detail=detail if detail else "Pipeline not found",
  287. )
  288. @router.get("/")
  289. async def get_pipelines(
  290. request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user)
  291. ):
  292. r = None
  293. try:
  294. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  295. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  296. r = requests.get(f"{url}/pipelines", headers={"Authorization": f"Bearer {key}"})
  297. r.raise_for_status()
  298. data = r.json()
  299. return {**data}
  300. except Exception as e:
  301. # Handle connection error here
  302. log.exception(f"Connection error: {e}")
  303. detail = None
  304. if r is not None:
  305. try:
  306. res = r.json()
  307. if "detail" in res:
  308. detail = res["detail"]
  309. except Exception:
  310. pass
  311. raise HTTPException(
  312. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  313. detail=detail if detail else "Pipeline not found",
  314. )
  315. @router.get("/{pipeline_id}/valves")
  316. async def get_pipeline_valves(
  317. request: Request,
  318. urlIdx: Optional[int],
  319. pipeline_id: str,
  320. user=Depends(get_admin_user),
  321. ):
  322. r = None
  323. try:
  324. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  325. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  326. r = requests.get(
  327. f"{url}/{pipeline_id}/valves", headers={"Authorization": f"Bearer {key}"}
  328. )
  329. r.raise_for_status()
  330. data = r.json()
  331. return {**data}
  332. except Exception as e:
  333. # Handle connection error here
  334. log.exception(f"Connection error: {e}")
  335. detail = None
  336. if r is not None:
  337. try:
  338. res = r.json()
  339. if "detail" in res:
  340. detail = res["detail"]
  341. except Exception:
  342. pass
  343. raise HTTPException(
  344. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  345. detail=detail if detail else "Pipeline not found",
  346. )
  347. @router.get("/{pipeline_id}/valves/spec")
  348. async def get_pipeline_valves_spec(
  349. request: Request,
  350. urlIdx: Optional[int],
  351. pipeline_id: str,
  352. user=Depends(get_admin_user),
  353. ):
  354. r = None
  355. try:
  356. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  357. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  358. r = requests.get(
  359. f"{url}/{pipeline_id}/valves/spec",
  360. headers={"Authorization": f"Bearer {key}"},
  361. )
  362. r.raise_for_status()
  363. data = r.json()
  364. return {**data}
  365. except Exception as e:
  366. # Handle connection error here
  367. log.exception(f"Connection error: {e}")
  368. detail = None
  369. if r is not None:
  370. try:
  371. res = r.json()
  372. if "detail" in res:
  373. detail = res["detail"]
  374. except Exception:
  375. pass
  376. raise HTTPException(
  377. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  378. detail=detail if detail else "Pipeline not found",
  379. )
  380. @router.post("/{pipeline_id}/valves/update")
  381. async def update_pipeline_valves(
  382. request: Request,
  383. urlIdx: Optional[int],
  384. pipeline_id: str,
  385. form_data: dict,
  386. user=Depends(get_admin_user),
  387. ):
  388. r = None
  389. try:
  390. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  391. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  392. r = requests.post(
  393. f"{url}/{pipeline_id}/valves/update",
  394. headers={"Authorization": f"Bearer {key}"},
  395. json={**form_data},
  396. )
  397. r.raise_for_status()
  398. data = r.json()
  399. return {**data}
  400. except Exception as e:
  401. # Handle connection error here
  402. log.exception(f"Connection error: {e}")
  403. detail = None
  404. if r is not None:
  405. try:
  406. res = r.json()
  407. if "detail" in res:
  408. detail = res["detail"]
  409. except Exception:
  410. pass
  411. raise HTTPException(
  412. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  413. detail=detail if detail else "Pipeline not found",
  414. )