openai.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076
  1. import asyncio
  2. import hashlib
  3. import json
  4. import logging
  5. from pathlib import Path
  6. from typing import Literal, Optional, overload
  7. import aiohttp
  8. from aiocache import cached
  9. import requests
  10. from urllib.parse import quote
  11. from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
  12. from fastapi.middleware.cors import CORSMiddleware
  13. from fastapi.responses import FileResponse, StreamingResponse
  14. from pydantic import BaseModel
  15. from starlette.background import BackgroundTask
  16. from open_webui.models.models import Models
  17. from open_webui.config import (
  18. CACHE_DIR,
  19. )
  20. from open_webui.env import (
  21. AIOHTTP_CLIENT_SESSION_SSL,
  22. AIOHTTP_CLIENT_TIMEOUT,
  23. AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
  24. ENABLE_FORWARD_USER_INFO_HEADERS,
  25. BYPASS_MODEL_ACCESS_CONTROL,
  26. )
  27. from open_webui.models.users import UserModel
  28. from open_webui.constants import ERROR_MESSAGES
  29. from open_webui.env import ENV, SRC_LOG_LEVELS
  30. from open_webui.utils.payload import (
  31. apply_model_params_to_body_openai,
  32. apply_model_system_prompt_to_body,
  33. )
  34. from open_webui.utils.misc import (
  35. convert_logit_bias_input_to_json,
  36. )
  37. from open_webui.utils.auth import get_admin_user, get_verified_user
  38. from open_webui.utils.access_control import has_access
  39. log = logging.getLogger(__name__)
  40. log.setLevel(SRC_LOG_LEVELS["OPENAI"])
  41. ##########################################
  42. #
  43. # Utility functions
  44. #
  45. ##########################################
async def send_get_request(url, key=None, user: Optional[UserModel] = None):
    """GET `url` and return the parsed JSON body.

    Adds bearer auth when `key` is set and forwards the requesting user's
    identity headers when ENABLE_FORWARD_USER_INFO_HEADERS is on.
    Returns None on any connection/parse failure (logged, never raised).
    """
    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
    try:
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(
                url,
                headers={
                    # Bearer auth only when a key is configured for the connection
                    **({"Authorization": f"Bearer {key}"} if key else {}),
                    # Optionally forward the requesting user's identity upstream
                    **(
                        {
                            "X-OpenWebUI-User-Name": quote(user.name),
                            "X-OpenWebUI-User-Id": quote(user.id),
                            "X-OpenWebUI-User-Email": quote(user.email),
                            "X-OpenWebUI-User-Role": quote(user.role),
                        }
                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
                        else {}
                    ),
                },
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as response:
                return await response.json()
    except Exception as e:
        # Handle connection error here
        log.error(f"Connection error: {e}")
        return None
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """Release a streamed aiohttp response and close its owning session.

    Used as a StreamingResponse background task so the upstream connection
    is torn down only after the client has consumed the stream.
    """
    if response:
        response.close()
    if session:
        await session.close()
  80. def openai_o_series_handler(payload):
  81. """
  82. Handle "o" series specific parameters
  83. """
  84. if "max_tokens" in payload:
  85. # Convert "max_tokens" to "max_completion_tokens" for all o-series models
  86. payload["max_completion_tokens"] = payload["max_tokens"]
  87. del payload["max_tokens"]
  88. # Handle system role conversion based on model type
  89. if payload["messages"][0]["role"] == "system":
  90. model_lower = payload["model"].lower()
  91. # Legacy models use "user" role instead of "system"
  92. if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
  93. payload["messages"][0]["role"] = "user"
  94. else:
  95. payload["messages"][0]["role"] = "developer"
  96. return payload
  97. ##########################################
  98. #
  99. # API routes
  100. #
  101. ##########################################
# Router instance for all OpenAI-compatible endpoints; mounted by the app.
router = APIRouter()
  103. @router.get("/config")
  104. async def get_config(request: Request, user=Depends(get_admin_user)):
  105. return {
  106. "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API,
  107. "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS,
  108. "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS,
  109. "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS,
  110. }
class OpenAIConfigForm(BaseModel):
    # Request body for /config/update.
    # ENABLE_OPENAI_API is optional; None means "not provided by the client".
    ENABLE_OPENAI_API: Optional[bool] = None
    OPENAI_API_BASE_URLS: list[str]
    OPENAI_API_KEYS: list[str]
    # Per-connection settings keyed by url index (legacy: keyed by url)
    OPENAI_API_CONFIGS: dict
  116. @router.post("/config/update")
  117. async def update_config(
  118. request: Request, form_data: OpenAIConfigForm, user=Depends(get_admin_user)
  119. ):
  120. request.app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
  121. request.app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
  122. request.app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS
  123. # Check if API KEYS length is same than API URLS length
  124. if len(request.app.state.config.OPENAI_API_KEYS) != len(
  125. request.app.state.config.OPENAI_API_BASE_URLS
  126. ):
  127. if len(request.app.state.config.OPENAI_API_KEYS) > len(
  128. request.app.state.config.OPENAI_API_BASE_URLS
  129. ):
  130. request.app.state.config.OPENAI_API_KEYS = (
  131. request.app.state.config.OPENAI_API_KEYS[
  132. : len(request.app.state.config.OPENAI_API_BASE_URLS)
  133. ]
  134. )
  135. else:
  136. request.app.state.config.OPENAI_API_KEYS += [""] * (
  137. len(request.app.state.config.OPENAI_API_BASE_URLS)
  138. - len(request.app.state.config.OPENAI_API_KEYS)
  139. )
  140. request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
  141. # Remove the API configs that are not in the API URLS
  142. keys = list(map(str, range(len(request.app.state.config.OPENAI_API_BASE_URLS))))
  143. request.app.state.config.OPENAI_API_CONFIGS = {
  144. key: value
  145. for key, value in request.app.state.config.OPENAI_API_CONFIGS.items()
  146. if key in keys
  147. }
  148. return {
  149. "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API,
  150. "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS,
  151. "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS,
  152. "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS,
  153. }
@router.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    """Proxy a text-to-speech request to the official OpenAI endpoint.

    Only works when "https://api.openai.com/v1" is among the configured
    base URLs (401 otherwise). Responses are cached on disk under
    CACHE_DIR keyed by the SHA-256 of the request body, so repeated
    requests for the same payload never hit the API twice.
    """
    idx = None
    try:
        # ValueError from .index() is handled below as "OpenAI not configured"
        idx = request.app.state.config.OPENAI_API_BASE_URLS.index(
            "https://api.openai.com/v1"
        )

        body = await request.body()
        name = hashlib.sha256(body).hexdigest()

        SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
        SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
        file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

        # Check if the file already exists in the cache
        if file_path.is_file():
            return FileResponse(file_path)

        url = request.app.state.config.OPENAI_API_BASE_URLS[idx]

        r = None
        try:
            # NOTE(review): blocking `requests` call inside an async handler —
            # ties up the event loop for the duration of the upstream request.
            r = requests.post(
                url=f"{url}/audio/speech",
                data=body,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}",
                    **(
                        {
                            "HTTP-Referer": "https://openwebui.com/",
                            "X-Title": "Open WebUI",
                        }
                        if "openrouter.ai" in url
                        else {}
                    ),
                    **(
                        {
                            "X-OpenWebUI-User-Name": quote(user.name),
                            "X-OpenWebUI-User-Id": quote(user.id),
                            "X-OpenWebUI-User-Email": quote(user.email),
                            "X-OpenWebUI-User-Role": quote(user.role),
                        }
                        if ENABLE_FORWARD_USER_INFO_HEADERS
                        else {}
                    ),
                },
                stream=True,
            )
            r.raise_for_status()

            # Save the streaming content to a file
            with open(file_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

            # Persist the request body next to the audio for later inspection
            with open(file_body_path, "w") as f:
                json.dump(json.loads(body.decode("utf-8")), f)

            # Return the saved file
            return FileResponse(file_path)
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error']}"
                except Exception:
                    detail = f"External: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
    except ValueError:
        # .index() did not find the official OpenAI URL among the base URLs
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
async def get_all_models_responses(request: Request, user: UserModel) -> list:
    """Fetch the /models listing from every configured OpenAI connection.

    Returns one entry per configured base URL, in order: the parsed
    response dict, a synthesized model list for connections with explicit
    model_ids, or None for disabled/failed connections. Per-connection
    prefix_id/tags/connection_type settings are applied to each model
    in place before returning.
    """
    if not request.app.state.config.ENABLE_OPENAI_API:
        return []

    # Check if API KEYS length is same than API URLS length
    num_urls = len(request.app.state.config.OPENAI_API_BASE_URLS)
    num_keys = len(request.app.state.config.OPENAI_API_KEYS)

    if num_keys != num_urls:
        # if there are more keys than urls, remove the extra keys
        if num_keys > num_urls:
            new_keys = request.app.state.config.OPENAI_API_KEYS[:num_urls]
            request.app.state.config.OPENAI_API_KEYS = new_keys
        # if there are more urls than keys, add empty keys
        else:
            request.app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)

    request_tasks = []
    for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS):
        if (str(idx) not in request.app.state.config.OPENAI_API_CONFIGS) and (
            url not in request.app.state.config.OPENAI_API_CONFIGS  # Legacy support
        ):
            # No per-connection config at all: always query upstream /models
            request_tasks.append(
                send_get_request(
                    f"{url}/models",
                    request.app.state.config.OPENAI_API_KEYS[idx],
                    user=user,
                )
            )
        else:
            api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
                str(idx),
                request.app.state.config.OPENAI_API_CONFIGS.get(
                    url, {}
                ),  # Legacy support
            )

            enable = api_config.get("enable", True)
            model_ids = api_config.get("model_ids", [])

            if enable:
                if len(model_ids) == 0:
                    request_tasks.append(
                        send_get_request(
                            f"{url}/models",
                            request.app.state.config.OPENAI_API_KEYS[idx],
                            user=user,
                        )
                    )
                else:
                    # Explicit model list configured: synthesize an OpenAI-style
                    # response locally instead of querying the connection.
                    model_list = {
                        "object": "list",
                        "data": [
                            {
                                "id": model_id,
                                "name": model_id,
                                "owned_by": "openai",
                                "openai": {"id": model_id},
                                "urlIdx": idx,
                            }
                            for model_id in model_ids
                        ],
                    }

                    # asyncio.sleep(0, result) yields an already-resolved awaitable
                    request_tasks.append(
                        asyncio.ensure_future(asyncio.sleep(0, model_list))
                    )
            else:
                # Connection disabled: None placeholder keeps indices aligned
                request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))

    responses = await asyncio.gather(*request_tasks)

    for idx, response in enumerate(responses):
        if response:
            url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
            api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
                str(idx),
                request.app.state.config.OPENAI_API_CONFIGS.get(
                    url, {}
                ),  # Legacy support
            )

            connection_type = api_config.get("connection_type", "external")
            prefix_id = api_config.get("prefix_id", None)
            tags = api_config.get("tags", [])

            for model in (
                response if isinstance(response, list) else response.get("data", [])
            ):
                if prefix_id:
                    # Namespace model ids so identically-named models from
                    # different connections don't collide
                    model["id"] = f"{prefix_id}.{model['id']}"

                if tags:
                    model["tags"] = tags

                if connection_type:
                    model["connection_type"] = connection_type

    log.debug(f"get_all_models:responses() {responses}")
    return responses
  312. async def get_filtered_models(models, user):
  313. # Filter models based on user access control
  314. filtered_models = []
  315. for model in models.get("data", []):
  316. model_info = Models.get_model_by_id(model["id"])
  317. if model_info:
  318. if user.id == model_info.user_id or has_access(
  319. user.id, type="read", access_control=model_info.access_control
  320. ):
  321. filtered_models.append(model)
  322. return filtered_models
@cached(ttl=1)  # coalesce bursts of concurrent calls for 1 second
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
    """Aggregate models from all enabled OpenAI connections.

    Returns {"data": [model, ...]} and, as a side effect, refreshes
    request.app.state.OPENAI_MODELS (id -> model lookup used for routing
    chat completions).
    """
    log.info("get_all_models()")

    if not request.app.state.config.ENABLE_OPENAI_API:
        return {"data": []}

    responses = await get_all_models_responses(request, user=user)

    def extract_data(response):
        # Normalize a connection response to a plain list of models (or None)
        if response and "data" in response:
            return response["data"]
        if isinstance(response, list):
            return response
        return None

    def merge_models_lists(model_lists):
        log.debug(f"merge_models_lists {model_lists}")
        merged_list = []

        for idx, models in enumerate(model_lists):
            if models is not None and "error" not in models:
                merged_list.extend(
                    [
                        {
                            **model,
                            "name": model.get("name", model["id"]),
                            "owned_by": "openai",
                            "openai": model,
                            "connection_type": model.get("connection_type", "external"),
                            "urlIdx": idx,
                        }
                        for model in models
                        if (model.get("id") or model.get("name"))
                        # For api.openai.com, hide non-chat models (audio,
                        # embedding, image, legacy completion families)
                        and (
                            "api.openai.com"
                            not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
                            or not any(
                                name in model["id"]
                                for name in [
                                    "babbage",
                                    "dall-e",
                                    "davinci",
                                    "embedding",
                                    "tts",
                                    "whisper",
                                ]
                            )
                        )
                    ]
                )

        return merged_list

    models = {"data": merge_models_lists(map(extract_data, responses))}
    log.debug(f"models: {models}")

    # Fast id -> model lookup used when routing chat completion requests
    request.app.state.OPENAI_MODELS = {model["id"]: model for model in models["data"]}
    return models
@router.get("/models")
@router.get("/models/{url_idx}")
async def get_models(
    request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
    """List available models.

    Without url_idx: return the merged, cached list across all connections.
    With url_idx: query that single connection live (or answer from the
    locally configured model_ids for Azure connections). Non-admin users
    only see models they can access, unless BYPASS_MODEL_ACCESS_CONTROL.
    """
    models = {
        "data": [],
    }

    if url_idx is None:
        models = await get_all_models(request, user=user)
    else:
        url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
        key = request.app.state.config.OPENAI_API_KEYS[url_idx]
        api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
            str(url_idx),
            request.app.state.config.OPENAI_API_CONFIGS.get(url, {}),  # Legacy support
        )

        r = None
        async with aiohttp.ClientSession(
            trust_env=True,
            timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
        ) as session:
            try:
                headers = {
                    "Content-Type": "application/json",
                    **(
                        {
                            "X-OpenWebUI-User-Name": quote(user.name),
                            "X-OpenWebUI-User-Id": quote(user.id),
                            "X-OpenWebUI-User-Email": quote(user.email),
                            "X-OpenWebUI-User-Role": quote(user.role),
                        }
                        if ENABLE_FORWARD_USER_INFO_HEADERS
                        else {}
                    ),
                }
                if api_config.get("azure", False):
                    # Azure connections are answered from the locally
                    # configured model id list instead of a live query
                    models = {
                        "data": api_config.get("model_ids", []) or [],
                        "object": "list",
                    }
                else:
                    headers["Authorization"] = f"Bearer {key}"
                    async with session.get(
                        f"{url}/models",
                        headers=headers,
                        ssl=AIOHTTP_CLIENT_SESSION_SSL,
                    ) as r:
                        if r.status != 200:
                            # Extract response error details if available
                            error_detail = f"HTTP Error: {r.status}"
                            res = await r.json()
                            if "error" in res:
                                error_detail = f"External Error: {res['error']}"
                            raise Exception(error_detail)

                        response_data = await r.json()

                        # Check if we're calling OpenAI API based on the URL
                        if "api.openai.com" in url:
                            # Filter models according to the specified conditions
                            response_data["data"] = [
                                model
                                for model in response_data.get("data", [])
                                if not any(
                                    name in model["id"]
                                    for name in [
                                        "babbage",
                                        "dall-e",
                                        "davinci",
                                        "embedding",
                                        "tts",
                                        "whisper",
                                    ]
                                )
                            ]

                        models = response_data
            except aiohttp.ClientError as e:
                # ClientError covers all aiohttp requests issues
                log.exception(f"Client error: {str(e)}")
                raise HTTPException(
                    status_code=500, detail="Open WebUI: Server Connection Error"
                )
            except Exception as e:
                log.exception(f"Unexpected error: {e}")
                error_detail = f"Unexpected error: {str(e)}"
                raise HTTPException(status_code=500, detail=error_detail)

    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
        models["data"] = await get_filtered_models(models, user)

    return models
class ConnectionVerificationForm(BaseModel):
    # Request body for /verify: a candidate connection to test.
    url: str
    key: str
    # Optional per-connection settings; "azure" and "api_version" are read
    config: Optional[dict] = None
@router.post("/verify")
async def verify_connection(
    form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
):
    """Verify an OpenAI-compatible (or Azure) connection by listing its models.

    Returns the upstream models response on success; raises
    HTTPException(500) carrying upstream error details otherwise.
    """
    url = form_data.url
    key = form_data.key

    api_config = form_data.config or {}

    async with aiohttp.ClientSession(
        trust_env=True,
        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
    ) as session:
        try:
            headers = {
                "Content-Type": "application/json",
                **(
                    {
                        "X-OpenWebUI-User-Name": quote(user.name),
                        "X-OpenWebUI-User-Id": quote(user.id),
                        "X-OpenWebUI-User-Email": quote(user.email),
                        "X-OpenWebUI-User-Role": quote(user.role),
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS
                    else {}
                ),
            }

            if api_config.get("azure", False):
                # Azure authenticates via api-key header + api-version query arg
                headers["api-key"] = key
                api_version = api_config.get("api_version", "") or "2023-03-15-preview"
                async with session.get(
                    url=f"{url}/openai/models?api-version={api_version}",
                    headers=headers,
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    if r.status != 200:
                        # Extract response error details if available
                        error_detail = f"HTTP Error: {r.status}"
                        res = await r.json()
                        if "error" in res:
                            error_detail = f"External Error: {res['error']}"
                        raise Exception(error_detail)

                    response_data = await r.json()
                    return response_data
            else:
                headers["Authorization"] = f"Bearer {key}"
                async with session.get(
                    f"{url}/models",
                    headers=headers,
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    if r.status != 200:
                        # Extract response error details if available
                        error_detail = f"HTTP Error: {r.status}"
                        res = await r.json()
                        if "error" in res:
                            error_detail = f"External Error: {res['error']}"
                        raise Exception(error_detail)

                    response_data = await r.json()
                    return response_data

        except aiohttp.ClientError as e:
            # ClientError covers all aiohttp requests issues
            log.exception(f"Client error: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Open WebUI: Server Connection Error"
            )
        except Exception as e:
            log.exception(f"Unexpected error: {e}")
            error_detail = f"Unexpected error: {str(e)}"
            raise HTTPException(status_code=500, detail=error_detail)
  534. def get_azure_allowed_params(api_version: str) -> set[str]:
  535. allowed_params = {
  536. "messages",
  537. "temperature",
  538. "role",
  539. "content",
  540. "contentPart",
  541. "contentPartImage",
  542. "enhancements",
  543. "dataSources",
  544. "n",
  545. "stream",
  546. "stop",
  547. "max_tokens",
  548. "presence_penalty",
  549. "frequency_penalty",
  550. "logit_bias",
  551. "user",
  552. "function_call",
  553. "functions",
  554. "tools",
  555. "tool_choice",
  556. "top_p",
  557. "log_probs",
  558. "top_logprobs",
  559. "response_format",
  560. "seed",
  561. "max_completion_tokens",
  562. }
  563. if api_version >= "2024-09-01-preview":
  564. allowed_params.add("stream_options")
  565. return allowed_params
  566. def convert_to_azure_payload(
  567. url,
  568. payload: dict,
  569. api_version: str
  570. ):
  571. model = payload.get("model", "")
  572. # Filter allowed parameters based on Azure OpenAI API
  573. allowed_params = get_azure_allowed_params(api_version)
  574. # Special handling for o-series models
  575. if model.startswith("o") and model.endswith("-mini"):
  576. # Convert max_tokens to max_completion_tokens for o-series models
  577. if "max_tokens" in payload:
  578. payload["max_completion_tokens"] = payload["max_tokens"]
  579. del payload["max_tokens"]
  580. # Remove temperature if not 1 for o-series models
  581. if "temperature" in payload and payload["temperature"] != 1:
  582. log.debug(
  583. f"Removing temperature parameter for o-series model {model} as only default value (1) is supported"
  584. )
  585. del payload["temperature"]
  586. # Filter out unsupported parameters
  587. payload = {k: v for k, v in payload.items() if k in allowed_params}
  588. url = f"{url}/openai/deployments/{model}"
  589. return url, payload
@router.post("/chat/completions")
async def generate_chat_completion(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user),
    bypass_filter: Optional[bool] = False,
):
    """Proxy a chat completion to the connection that serves the model.

    Applies per-model param/system-prompt overrides, enforces access
    control for non-admin users, adapts the payload for o-series and
    Azure endpoints, and streams SSE responses straight through to the
    client (cleanup deferred to a background task).
    """
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    idx = 0

    payload = {**form_data}
    metadata = payload.pop("metadata", None)

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    # Check model info and override the payload
    if model_info:
        if model_info.base_model_id:
            # Route to the underlying base model for derived/custom models
            payload["model"] = model_info.base_model_id
            model_id = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            system = params.pop("system", None)
            payload = apply_model_params_to_body_openai(params, payload)
            payload = apply_model_system_prompt_to_body(system, payload, metadata, user)

        # Check if user has access to the model
        if not bypass_filter and user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                # 403 with a deliberately unrevealing message
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    elif not bypass_filter:
        # Models without a DB entry are only reachable by admins
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    # Refresh the id -> model map, then resolve which connection serves it
    await get_all_models(request, user=user)
    model = request.app.state.OPENAI_MODELS.get(model_id)
    if model:
        idx = model["urlIdx"]
    else:
        raise HTTPException(
            status_code=404,
            detail="Model not found",
        )

    # Get the API config for the model
    api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
        str(idx),
        request.app.state.config.OPENAI_API_CONFIGS.get(
            request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
        ),  # Legacy support
    )

    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        # Strip the connection namespace prefix before forwarding upstream
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    # Add user info to the payload if the model is a pipeline
    if "pipeline" in model and model.get("pipeline"):
        payload["user"] = {
            "name": user.name,
            "id": user.id,
            "email": user.email,
            "role": user.role,
        }

    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
    key = request.app.state.config.OPENAI_API_KEYS[idx]

    # Check if model is from "o" series
    is_o_series = payload["model"].lower().startswith(("o1", "o3", "o4"))
    if is_o_series:
        payload = openai_o_series_handler(payload)
    elif "api.openai.com" not in url:
        # Remove "max_completion_tokens" from the payload for backward compatibility
        if "max_completion_tokens" in payload:
            payload["max_tokens"] = payload["max_completion_tokens"]
            del payload["max_completion_tokens"]

    # Never send both token-limit fields; prefer max_completion_tokens
    if "max_tokens" in payload and "max_completion_tokens" in payload:
        del payload["max_tokens"]

    # Convert the modified body back to JSON
    if "logit_bias" in payload:
        payload["logit_bias"] = json.loads(
            convert_logit_bias_input_to_json(payload["logit_bias"])
        )

    headers = {
        "Content-Type": "application/json",
        **(
            {
                "HTTP-Referer": "https://openwebui.com/",
                "X-Title": "Open WebUI",
            }
            if "openrouter.ai" in url
            else {}
        ),
        **(
            {
                "X-OpenWebUI-User-Name": quote(user.name),
                "X-OpenWebUI-User-Id": quote(user.id),
                "X-OpenWebUI-User-Email": quote(user.email),
                "X-OpenWebUI-User-Role": quote(user.role),
            }
            if ENABLE_FORWARD_USER_INFO_HEADERS
            else {}
        ),
    }

    if api_config.get("azure", False):
        api_version = api_config.get("api_version", "2023-03-15-preview")
        request_url, payload = convert_to_azure_payload(url, payload, api_version)
        headers["api-key"] = key
        headers["api-version"] = api_version
        request_url = f"{request_url}/chat/completions?api-version={api_version}"
    else:
        request_url = f"{url}/chat/completions"
        headers["Authorization"] = f"Bearer {key}"

    payload = json.dumps(payload)

    r = None
    session = None
    streaming = False
    response = None

    try:
        session = aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        )

        r = await session.request(
            method="POST",
            url=request_url,
            data=payload,
            headers=headers,
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        )

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            try:
                response = await r.json()
            except Exception as e:
                # Non-JSON body: keep the raw text for the error detail below
                log.error(e)
                response = await r.text()

            r.raise_for_status()
            return response
    except Exception as e:
        log.exception(e)

        detail = None
        if isinstance(response, dict):
            if "error" in response:
                detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
        elif isinstance(response, str):
            detail = response

        raise HTTPException(
            status_code=r.status if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )
    finally:
        # For streaming responses, cleanup happens in the background task
        if not streaming and session:
            if r:
                r.close()
            await session.close()
async def embeddings(request: Request, form_data: dict, user):
    """Call the embeddings endpoint of an OpenAI-compatible provider.

    Args:
        request (Request): The FastAPI request context.
        form_data (dict): OpenAI-compatible embeddings payload.
        user (UserModel): The authenticated user.
    Returns:
        dict: OpenAI-compatible embeddings response.
    """
    idx = 0

    # Prepare payload/body
    body = json.dumps(form_data)

    # Find correct backend url/key based on model
    await get_all_models(request, user=user)
    model_id = form_data.get("model")
    models = request.app.state.OPENAI_MODELS
    if model_id in models:
        idx = models[model_id]["urlIdx"]  # unknown models fall back to connection 0

    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
    key = request.app.state.config.OPENAI_API_KEYS[idx]

    r = None
    session = None
    streaming = False

    try:
        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method="POST",
            url=f"{url}/embeddings",
            data=body,
            headers={
                "Authorization": f"Bearer {key}",
                "Content-Type": "application/json",
                **(
                    {
                        "X-OpenWebUI-User-Name": quote(user.name),
                        "X-OpenWebUI-User-Id": quote(user.id),
                        "X-OpenWebUI-User-Email": quote(user.email),
                        "X-OpenWebUI-User-Role": quote(user.role),
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
        )
        r.raise_for_status()

        if "text/event-stream" in r.headers.get("Content-Type", ""):
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            response_data = await r.json()
            return response_data
    except Exception as e:
        log.exception(e)

        detail = None
        if r is not None:
            try:
                res = await r.json()
                if "error" in res:
                    detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                detail = f"External: {e}"

        raise HTTPException(
            status_code=r.status if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )
    finally:
        # For streaming responses, cleanup happens in the background task
        if not streaming and session:
            if r:
                r.close()
            await session.close()
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """
    Deprecated: proxy all requests to OpenAI API
    """
    # Always routed to the first configured connection (idx 0)
    body = await request.body()

    idx = 0
    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
    key = request.app.state.config.OPENAI_API_KEYS[idx]

    api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
        str(idx),
        request.app.state.config.OPENAI_API_CONFIGS.get(
            request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
        ),  # Legacy support
    )

    r = None
    session = None
    streaming = False

    try:
        headers = {
            "Content-Type": "application/json",
            **(
                {
                    "X-OpenWebUI-User-Name": quote(user.name),
                    "X-OpenWebUI-User-Id": quote(user.id),
                    "X-OpenWebUI-User-Email": quote(user.email),
                    "X-OpenWebUI-User-Role": quote(user.role),
                }
                if ENABLE_FORWARD_USER_INFO_HEADERS
                else {}
            ),
        }

        if api_config.get("azure", False):
            # Azure: re-shape the JSON body and address the deployment URL
            api_version = api_config.get("api_version", "2023-03-15-preview")
            headers["api-key"] = key
            headers["api-version"] = api_version

            payload = json.loads(body)
            url, payload = convert_to_azure_payload(url, payload, api_version)
            body = json.dumps(payload).encode()

            request_url = f"{url}/{path}?api-version={api_version}"
        else:
            headers["Authorization"] = f"Bearer {key}"
            request_url = f"{url}/{path}"

        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method=request.method,
            url=request_url,
            data=body,
            headers=headers,
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        )
        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            response_data = await r.json()
            return response_data
    except Exception as e:
        log.exception(e)

        detail = None
        if r is not None:
            try:
                res = await r.json()
                log.error(res)
                if "error" in res:
                    detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                detail = f"External: {e}"

        raise HTTPException(
            status_code=r.status if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )
    finally:
        # For streaming responses, cleanup happens in the background task
        if not streaming and session:
            if r:
                r.close()
            await session.close()