# ollama.py
# TODO: Implement a more intelligent load balancing mechanism for distributing
# requests among multiple backend instances. The current implementation picks a
# node at random (random.choice) rather than doing true round-robin. Consider
# algorithms like weighted round-robin, least connections, or least response
# time for better resource utilization and performance; see the hypothetical
# sketch after the utility-function banner below.
import asyncio
import json
import logging
import os
import random
import re
import time
from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse

import aiohttp
import requests
from aiocache import cached
from fastapi import (
    APIRouter,
    Depends,
    FastAPI,
    File,
    HTTPException,
    Request,
    UploadFile,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict, validator
from starlette.background import BackgroundTask

from open_webui.models.chats import Chats
from open_webui.models.models import Models
from open_webui.models.users import UserModel
from open_webui.utils.misc import calculate_sha256
from open_webui.utils.payload import (
    apply_model_params_to_body_ollama,
    apply_model_params_to_body_openai,
    apply_model_system_prompt_to_body,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access
from open_webui.config import UPLOAD_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
    BYPASS_MODEL_ACCESS_CONTROL,
    ENABLE_FORWARD_USER_INFO_HEADERS,
    ENV,
    SRC_LOG_LEVELS,
)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])


##########################################
#
# Utility functions
#
##########################################
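

# A minimal sketch of the smarter strategies mentioned in the header TODO,
# assuming per-node weights or in-flight request counts are tracked elsewhere.
# The names below (_pick_url_idx_weighted, _pick_url_idx_least_connections,
# active_connections) are hypothetical; the routes in this file still call
# random.choice directly.
def _pick_url_idx_weighted(url_indices: list[int], weights: dict[int, float]) -> int:
    # Weighted random choice: nodes with larger weights are selected more often.
    return random.choices(
        url_indices,
        weights=[weights.get(idx, 1.0) for idx in url_indices],
        k=1,
    )[0]


def _pick_url_idx_least_connections(
    url_indices: list[int], active_connections: dict[int, int]
) -> int:
    # Least-connections choice: prefer the node with the fewest in-flight requests.
    return min(url_indices, key=lambda idx: active_connections.get(idx, 0))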


async def send_get_request(url, key=None, user: UserModel = None):
    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
    try:
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(
                url,
                headers={
                    "Content-Type": "application/json",
                    **({"Authorization": f"Bearer {key}"} if key else {}),
                    **(
                        {
                            "X-OpenWebUI-User-Name": user.name,
                            "X-OpenWebUI-User-Id": user.id,
                            "X-OpenWebUI-User-Email": user.email,
                            "X-OpenWebUI-User-Role": user.role,
                        }
                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
                        else {}
                    ),
                },
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as response:
                return await response.json()
    except Exception as e:
        # Handle connection error here
        log.error(f"Connection error: {e}")
        return None
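

# Usage sketch (hypothetical URL): returns the parsed JSON body, or None on
# any connection error, so callers can aggregate across flaky nodes.
#   tags = await send_get_request("http://localhost:11434/api/tags")
#   names = [m["model"] for m in (tags or {}).get("models", [])]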


async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    if response:
        response.close()
    if session:
        await session.close()


async def send_post_request(
    url: str,
    payload: Union[str, bytes],
    stream: bool = True,
    key: Optional[str] = None,
    content_type: Optional[str] = None,
    user: UserModel = None,
):
    r = None
    try:
        session = aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        )

        r = await session.post(
            url,
            data=payload,
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
                **(
                    {
                        "X-OpenWebUI-User-Name": user.name,
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        )

        if not r.ok:
            try:
                res = await r.json()
                await cleanup_response(r, session)
                if "error" in res:
                    raise HTTPException(status_code=r.status, detail=res["error"])
            except HTTPException as e:
                raise e  # Re-raise HTTPException to be handled by FastAPI
            except Exception as e:
                log.error(f"Failed to parse error response: {e}")
                raise HTTPException(
                    status_code=r.status,
                    detail="Open WebUI: Server Connection Error",
                )

        r.raise_for_status()  # Raises an error for bad responses (4xx, 5xx)

        if stream:
            response_headers = dict(r.headers)
            if content_type:
                response_headers["Content-Type"] = content_type
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=response_headers,
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            res = await r.json()
            await cleanup_response(r, session)
            return res
    except HTTPException as e:
        raise e  # Re-raise HTTPException to be handled by FastAPI
    except Exception as e:
        detail = f"Ollama: {e}"
        raise HTTPException(
            status_code=r.status if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )


def get_api_key(idx, url, configs):
    parsed_url = urlparse(url)
    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
    return configs.get(str(idx), configs.get(base_url, {})).get(
        "key", None
    )  # Legacy support
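

# Lookup precedence example (hypothetical values): the stringified index wins;
# the bare base URL is the legacy fallback.
#   configs = {"0": {"key": "new"}, "http://host:11434": {"key": "legacy"}}
#   get_api_key(0, "http://host:11434", configs)  # -> "new"
#   get_api_key(1, "http://host:11434", configs)  # -> "legacy"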


##########################################
#
# API routes
#
##########################################

router = APIRouter()


@router.head("/")
@router.get("/")
async def get_status():
    return {"status": True}


class ConnectionVerificationForm(BaseModel):
    url: str
    key: Optional[str] = None


@router.post("/verify")
async def verify_connection(
    form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
):
    url = form_data.url
    key = form_data.key

    async with aiohttp.ClientSession(
        trust_env=True,
        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
    ) as session:
        try:
            async with session.get(
                f"{url}/api/version",
                headers={
                    **({"Authorization": f"Bearer {key}"} if key else {}),
                    **(
                        {
                            "X-OpenWebUI-User-Name": user.name,
                            "X-OpenWebUI-User-Id": user.id,
                            "X-OpenWebUI-User-Email": user.email,
                            "X-OpenWebUI-User-Role": user.role,
                        }
                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
                        else {}
                    ),
                },
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                if r.status != 200:
                    detail = f"HTTP Error: {r.status}"
                    res = await r.json()
                    if "error" in res:
                        detail = f"External Error: {res['error']}"
                    raise Exception(detail)

                data = await r.json()
                return data
        except aiohttp.ClientError as e:
            log.exception(f"Client error: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Open WebUI: Server Connection Error"
            )
        except Exception as e:
            log.exception(f"Unexpected error: {e}")
            error_detail = f"Unexpected error: {str(e)}"
            raise HTTPException(status_code=500, detail=error_detail)


@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    return {
        "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
        "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
        "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
    }


class OllamaConfigForm(BaseModel):
    ENABLE_OLLAMA_API: Optional[bool] = None
    OLLAMA_BASE_URLS: list[str]
    OLLAMA_API_CONFIGS: dict


@router.post("/config/update")
async def update_config(
    request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
):
    request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
    request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
    request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS

    # Drop API configs whose keys no longer correspond to a configured URL index
    keys = list(map(str, range(len(request.app.state.config.OLLAMA_BASE_URLS))))
    request.app.state.config.OLLAMA_API_CONFIGS = {
        key: value
        for key, value in request.app.state.config.OLLAMA_API_CONFIGS.items()
        if key in keys
    }

    return {
        "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
        "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
        "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
    }


def merge_ollama_models_lists(model_lists):
    merged_models = {}

    for idx, model_list in enumerate(model_lists):
        if model_list is not None:
            for model in model_list:
                id = model["model"]
                if id not in merged_models:
                    model["urls"] = [idx]
                    merged_models[id] = model
                else:
                    merged_models[id]["urls"].append(idx)

    return list(merged_models.values())
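

# Merge example (hypothetical data): a model served by two nodes ends up with
# both URL indices attached.
#   merge_ollama_models_lists(
#       [[{"model": "llama3:latest"}], [{"model": "llama3:latest"}]]
#   )
#   # -> [{"model": "llama3:latest", "urls": [0, 1]}]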


@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel = None):
    log.info("get_all_models()")
    if request.app.state.config.ENABLE_OLLAMA_API:
        request_tasks = []
        for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
            if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
                url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
            ):
                request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
            else:
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(
                        url, {}
                    ),  # Legacy support
                )

                enable = api_config.get("enable", True)
                key = api_config.get("key", None)

                if enable:
                    request_tasks.append(
                        send_get_request(f"{url}/api/tags", key, user=user)
                    )
                else:
                    request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))

        responses = await asyncio.gather(*request_tasks)

        for idx, response in enumerate(responses):
            if response:
                url = request.app.state.config.OLLAMA_BASE_URLS[idx]
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(
                        url, {}
                    ),  # Legacy support
                )

                connection_type = api_config.get("connection_type", "local")
                prefix_id = api_config.get("prefix_id", None)
                tags = api_config.get("tags", [])
                model_ids = api_config.get("model_ids", [])

                if len(model_ids) != 0 and "models" in response:
                    response["models"] = list(
                        filter(
                            lambda model: model["model"] in model_ids,
                            response["models"],
                        )
                    )

                for model in response.get("models", []):
                    if prefix_id:
                        model["model"] = f"{prefix_id}.{model['model']}"
                    if tags:
                        model["tags"] = tags
                    if connection_type:
                        model["connection_type"] = connection_type

        models = {
            "models": merge_ollama_models_lists(
                map(
                    lambda response: response.get("models", []) if response else None,
                    responses,
                )
            )
        }

        try:
            loaded_models = await get_ollama_loaded_models(request, user=user)
            expires_map = {
                m["name"]: m["expires_at"]
                for m in loaded_models["models"]
                if "expires_at" in m
            }

            for m in models["models"]:
                if m["name"] in expires_map:
                    # Parse ISO8601 datetime with offset, get unix timestamp as int
                    dt = datetime.fromisoformat(expires_map[m["name"]])
                    m["expires_at"] = int(dt.timestamp())
        except Exception as e:
            log.debug(f"Failed to get loaded models: {e}")
    else:
        models = {"models": []}

    request.app.state.OLLAMA_MODELS = {
        model["model"]: model for model in models["models"]
    }
    return models


async def get_filtered_models(models, user):
    # Filter models based on user access control
    filtered_models = []
    for model in models.get("models", []):
        model_info = Models.get_model_by_id(model["model"])
        if model_info:
            if user.id == model_info.user_id or has_access(
                user.id, type="read", access_control=model_info.access_control
            ):
                filtered_models.append(model)
    return filtered_models


@router.get("/api/tags")
@router.get("/api/tags/{url_idx}")
async def get_ollama_tags(
    request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
    models = []

    if url_idx is None:
        models = await get_all_models(request, user=user)
    else:
        url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
        key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

        r = None
        try:
            r = requests.request(
                method="GET",
                url=f"{url}/api/tags",
                headers={
                    **({"Authorization": f"Bearer {key}"} if key else {}),
                    **(
                        {
                            "X-OpenWebUI-User-Name": user.name,
                            "X-OpenWebUI-User-Id": user.id,
                            "X-OpenWebUI-User-Email": user.email,
                            "X-OpenWebUI-User-Role": user.role,
                        }
                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
                        else {}
                    ),
                },
            )
            r.raise_for_status()

            models = r.json()
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"Ollama: {res['error']}"
                except Exception:
                    detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )

    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
        models["models"] = await get_filtered_models(models, user)

    return models


@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)):
    """
    List models that are currently loaded into Ollama memory, and which node they are loaded on.
    """
    if request.app.state.config.ENABLE_OLLAMA_API:
        request_tasks = []
        for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
            if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
                url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
            ):
                request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
            else:
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(
                        url, {}
                    ),  # Legacy support
                )

                enable = api_config.get("enable", True)
                key = api_config.get("key", None)

                if enable:
                    request_tasks.append(
                        send_get_request(f"{url}/api/ps", key, user=user)
                    )
                else:
                    request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))

        responses = await asyncio.gather(*request_tasks)

        for idx, response in enumerate(responses):
            if response:
                url = request.app.state.config.OLLAMA_BASE_URLS[idx]
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(
                        url, {}
                    ),  # Legacy support
                )

                prefix_id = api_config.get("prefix_id", None)

                for model in response.get("models", []):
                    if prefix_id:
                        model["model"] = f"{prefix_id}.{model['model']}"

        models = {
            "models": merge_ollama_models_lists(
                map(
                    lambda response: response.get("models", []) if response else None,
                    responses,
                )
            )
        }
    else:
        models = {"models": []}

    return models


@router.get("/api/version")
@router.get("/api/version/{url_idx}")
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
    if request.app.state.config.ENABLE_OLLAMA_API:
        if url_idx is None:
            # returns lowest version
            request_tasks = []

            for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(
                        url, {}
                    ),  # Legacy support
                )

                enable = api_config.get("enable", True)
                key = api_config.get("key", None)

                if enable:
                    request_tasks.append(
                        send_get_request(
                            f"{url}/api/version",
                            key,
                        )
                    )

            responses = await asyncio.gather(*request_tasks)
            responses = list(filter(lambda x: x is not None, responses))

            if len(responses) > 0:
                lowest_version = min(
                    responses,
                    key=lambda x: tuple(
                        map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
                    ),
                )

                return {"version": lowest_version["version"]}
            else:
                raise HTTPException(
                    status_code=500,
                    detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
                )
        else:
            url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

            r = None
            try:
                r = requests.request(method="GET", url=f"{url}/api/version")
                r.raise_for_status()

                return r.json()
            except Exception as e:
                log.exception(e)

                detail = None
                if r is not None:
                    try:
                        res = r.json()
                        if "error" in res:
                            detail = f"Ollama: {res['error']}"
                    except Exception:
                        detail = f"Ollama: {e}"

                raise HTTPException(
                    status_code=r.status_code if r else 500,
                    detail=detail if detail else "Open WebUI: Server Connection Error",
                )
    else:
        return {"version": False}
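

# Version normalization example: the regex above strips a leading "v" and any
# pre-release suffix before the numeric tuple comparison.
#   re.sub(r"^v|-.*", "", "v0.5.7-rc1")     # -> "0.5.7"
#   tuple(map(int, "0.5.7".split(".")))     # -> (0, 5, 7), so 0.10.0 > 0.9.9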


class ModelNameForm(BaseModel):
    name: str


@router.post("/api/unload")
async def unload_model(
    request: Request,
    form_data: ModelNameForm,
    user=Depends(get_admin_user),
):
    model_name = form_data.name
    if not model_name:
        raise HTTPException(
            status_code=400, detail="Missing 'name' of model to unload."
        )

    # Refresh/load models if needed, get mapping from name to URLs
    await get_all_models(request, user=user)
    models = request.app.state.OLLAMA_MODELS

    # Canonicalize model name (if not supplied with version)
    if ":" not in model_name:
        model_name = f"{model_name}:latest"

    if model_name not in models:
        raise HTTPException(
            status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name)
        )

    url_indices = models[model_name]["urls"]

    # Send unload to ALL url_indices
    results = []
    errors = []

    for idx in url_indices:
        url = request.app.state.config.OLLAMA_BASE_URLS[idx]
        api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
            str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
        )
        key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

        prefix_id = api_config.get("prefix_id", None)
        if prefix_id and model_name.startswith(f"{prefix_id}."):
            model_name = model_name[len(f"{prefix_id}.") :]

        payload = {"model": model_name, "keep_alive": 0, "prompt": ""}
        try:
            res = await send_post_request(
                url=f"{url}/api/generate",
                payload=json.dumps(payload),
                stream=False,
                key=key,
                user=user,
            )
            results.append({"url_idx": idx, "success": True, "response": res})
        except Exception as e:
            log.exception(f"Failed to unload model on node {idx}: {e}")
            errors.append({"url_idx": idx, "success": False, "error": str(e)})

    if len(errors) > 0:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to unload model on {len(errors)} nodes: {errors}",
        )

    return {"status": True}


@router.post("/api/pull")
@router.post("/api/pull/{url_idx}")
async def pull_model(
    request: Request,
    form_data: ModelNameForm,
    url_idx: int = 0,
    user=Depends(get_admin_user),
):
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    # Admin should be able to pull models from any source
    payload = {**form_data.model_dump(exclude_none=True), "insecure": True}

    return await send_post_request(
        url=f"{url}/api/pull",
        payload=json.dumps(payload),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
    )


class PushModelForm(BaseModel):
    name: str
    insecure: Optional[bool] = None
    stream: Optional[bool] = None


@router.delete("/api/push")
@router.delete("/api/push/{url_idx}")
async def push_model(
    request: Request,
    form_data: PushModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        if form_data.name in models:
            url_idx = models[form_data.name]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.debug(f"url: {url}")

    return await send_post_request(
        url=f"{url}/api/push",
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
    )


class CreateModelForm(BaseModel):
    model: Optional[str] = None
    stream: Optional[bool] = None
    path: Optional[str] = None

    model_config = ConfigDict(extra="allow")


@router.post("/api/create")
@router.post("/api/create/{url_idx}")
async def create_model(
    request: Request,
    form_data: CreateModelForm,
    url_idx: int = 0,
    user=Depends(get_admin_user),
):
    log.debug(f"form_data: {form_data}")
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

    return await send_post_request(
        url=f"{url}/api/create",
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
    )


class CopyModelForm(BaseModel):
    source: str
    destination: str


@router.post("/api/copy")
@router.post("/api/copy/{url_idx}")
async def copy_model(
    request: Request,
    form_data: CopyModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        if form_data.source in models:
            url_idx = models[form_data.source]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    r = None  # Initialize so the except block can safely check it
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/copy",
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
                **(
                    {
                        "X-OpenWebUI-User-Name": user.name,
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")
        return True
    except Exception as e:
        log.exception(e)

        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )


@router.delete("/api/delete")
@router.delete("/api/delete/{url_idx}")
async def delete_model(
    request: Request,
    form_data: ModelNameForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        if form_data.name in models:
            url_idx = models[form_data.name]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    r = None  # Initialize so the except block can safely check it
    try:
        r = requests.request(
            method="DELETE",
            url=f"{url}/api/delete",
            data=form_data.model_dump_json(exclude_none=True).encode(),
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
                **(
                    {
                        "X-OpenWebUI-User-Name": user.name,
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")
        return True
    except Exception as e:
        log.exception(e)

        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )


@router.post("/api/show")
async def show_model_info(
    request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
    await get_all_models(request, user=user)
    models = request.app.state.OLLAMA_MODELS

    if form_data.name not in models:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
        )

    url_idx = random.choice(models[form_data.name]["urls"])
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    r = None  # Initialize so the except block can safely check it
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/show",
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
                **(
                    {
                        "X-OpenWebUI-User-Name": user.name,
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        return r.json()
    except Exception as e:
        log.exception(e)

        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )


class GenerateEmbedForm(BaseModel):
    model: str
    input: list[str] | str
    truncate: Optional[bool] = None
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


@router.post("/api/embed")
@router.post("/api/embed/{url_idx}")
async def embed(
    request: Request,
    form_data: GenerateEmbedForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    log.info(f"generate_ollama_batch_embeddings {form_data}")

    if url_idx is None:
        await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model

        if ":" not in model:
            model = f"{model}:latest"

        if model in models:
            url_idx = random.choice(models[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        form_data.model = form_data.model.replace(f"{prefix_id}.", "")

    r = None  # Initialize so the except block can safely check it
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embed",
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
                **(
                    {
                        "X-OpenWebUI-User-Name": user.name,
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        data = r.json()
        return data
    except Exception as e:
        log.exception(e)

        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )


class GenerateEmbeddingsForm(BaseModel):
    model: str
    prompt: str
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


@router.post("/api/embeddings")
@router.post("/api/embeddings/{url_idx}")
async def embeddings(
    request: Request,
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    log.info(f"generate_ollama_embeddings {form_data}")

    if url_idx is None:
        await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model

        if ":" not in model:
            model = f"{model}:latest"

        if model in models:
            url_idx = random.choice(models[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        form_data.model = form_data.model.replace(f"{prefix_id}.", "")

    r = None  # Initialize so the except block can safely check it
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embeddings",
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
                **(
                    {
                        "X-OpenWebUI-User-Name": user.name,
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        data = r.json()
        return data
    except Exception as e:
        log.exception(e)

        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )


class GenerateCompletionForm(BaseModel):
    model: str
    prompt: str
    suffix: Optional[str] = None
    images: Optional[list[str]] = None
    format: Optional[Union[dict, str]] = None
    options: Optional[dict] = None
    system: Optional[str] = None
    template: Optional[str] = None
    context: Optional[list[int]] = None
    stream: Optional[bool] = True
    raw: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None


@router.post("/api/generate")
@router.post("/api/generate/{url_idx}")
async def generate_completion(
    request: Request,
    form_data: GenerateCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    if url_idx is None:
        await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model

        if ":" not in model:
            model = f"{model}:latest"

        if model in models:
            url_idx = random.choice(models[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )

    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        form_data.model = form_data.model.replace(f"{prefix_id}.", "")

    return await send_post_request(
        url=f"{url}/api/generate",
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
    )


class ChatMessage(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[list[dict]] = None
    images: Optional[list[str]] = None

    @validator("content", pre=True)
    @classmethod
    def check_at_least_one_field(cls, field_value, values, **kwargs):
        # Raise an error if both 'content' and 'tool_calls' are None
        if field_value is None and (
            "tool_calls" not in values or values["tool_calls"] is None
        ):
            raise ValueError(
                "At least one of 'content' or 'tool_calls' must be provided"
            )

        return field_value


class GenerateChatCompletionForm(BaseModel):
    model: str
    messages: list[ChatMessage]
    format: Optional[Union[dict, str]] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    stream: Optional[bool] = True
    keep_alive: Optional[Union[int, str]] = None
    tools: Optional[list[dict]] = None

    model_config = ConfigDict(
        extra="allow",
    )


async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
    if url_idx is None:
        models = request.app.state.OLLAMA_MODELS
        if model not in models:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
            )
        url_idx = random.choice(models[model].get("urls", []))
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    return url, url_idx
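

# The random.choice above is the naive selection the header TODO refers to; a
# smarter policy could swap in the hypothetical sketches defined after the
# utility-function banner, e.g.:
#   url_idx = _pick_url_idx_least_connections(
#       models[model].get("urls", []), in_flight_counts
#   )
# where in_flight_counts would be a hypothetical dict of per-node request counts.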


@router.post("/api/chat")
@router.post("/api/chat/{url_idx}")
async def generate_chat_completion(
    request: Request,
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
    bypass_filter: Optional[bool] = False,
):
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    metadata = form_data.pop("metadata", None)
    try:
        form_data = GenerateChatCompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    if isinstance(form_data, BaseModel):
        payload = {**form_data.model_dump(exclude_none=True)}
    if "metadata" in payload:
        del payload["metadata"]

    model_id = payload["model"]
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            system = params.pop("system", None)
            payload = apply_model_params_to_body_ollama(params, payload)
            payload = apply_model_system_prompt_to_body(system, payload, metadata, user)

        # Check if user has access to the model
        if not bypass_filter and user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    elif not bypass_filter:
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )

    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await send_post_request(
        url=f"{url}/api/chat",
        payload=json.dumps(payload),
        stream=form_data.stream,
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        content_type="application/x-ndjson",
        user=user,
    )


# TODO: we should update this part once Ollama supports other types
class OpenAIChatMessageContent(BaseModel):
    type: str
    model_config = ConfigDict(extra="allow")


class OpenAIChatMessage(BaseModel):
    role: str
    content: Union[Optional[str], list[OpenAIChatMessageContent]]

    model_config = ConfigDict(extra="allow")


class OpenAIChatCompletionForm(BaseModel):
    model: str
    messages: list[OpenAIChatMessage]

    model_config = ConfigDict(extra="allow")


class OpenAICompletionForm(BaseModel):
    model: str
    prompt: str

    model_config = ConfigDict(extra="allow")


@router.post("/v1/completions")
@router.post("/v1/completions/{url_idx}")
async def generate_openai_completion(
    request: Request,
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    try:
        form_data = OpenAICompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
    if "metadata" in payload:
        del payload["metadata"]

    model_id = form_data.model
    if ":" not in model_id:
        model_id = f"{model_id}:latest"

    model_info = Models.get_model_by_id(model_id)
    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            payload = apply_model_params_to_body_openai(params, payload)

        # Check if user has access to the model
        if user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    else:
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )

    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await send_post_request(
        url=f"{url}/v1/completions",
        payload=json.dumps(payload),
        stream=payload.get("stream", False),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
    )


@router.post("/v1/chat/completions")
@router.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion(
    request: Request,
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    metadata = form_data.pop("metadata", None)

    try:
        completion_form = OpenAIChatCompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
    if "metadata" in payload:
        del payload["metadata"]

    model_id = completion_form.model
    if ":" not in model_id:
        model_id = f"{model_id}:latest"

    model_info = Models.get_model_by_id(model_id)
    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            system = params.pop("system", None)
            payload = apply_model_params_to_body_openai(params, payload)
            payload = apply_model_system_prompt_to_body(system, payload, metadata, user)

        # Check if user has access to the model
        if user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    else:
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )

    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await send_post_request(
        url=f"{url}/v1/chat/completions",
        payload=json.dumps(payload),
        stream=payload.get("stream", False),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
    )


@router.get("/v1/models")
@router.get("/v1/models/{url_idx}")
async def get_openai_models(
    request: Request,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    models = []

    if url_idx is None:
        model_list = await get_all_models(request, user=user)
        models = [
            {
                "id": model["model"],
                "object": "model",
                "created": int(time.time()),
                "owned_by": "openai",
            }
            for model in model_list["models"]
        ]
    else:
        url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

        r = None  # Initialize so the except block can safely check it
        try:
            r = requests.request(method="GET", url=f"{url}/api/tags")
            r.raise_for_status()

            model_list = r.json()
            models = [
                {
                    "id": model["model"],
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "openai",
                }
                for model in model_list["models"]
            ]
        except Exception as e:
            log.exception(e)

            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"Ollama: {res['error']}"
                except Exception:
                    error_detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=error_detail,
            )

    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
        # Filter models based on user access control
        filtered_models = []
        for model in models:
            model_info = Models.get_model_by_id(model["id"])
            if model_info:
                if user.id == model_info.user_id or has_access(
                    user.id, type="read", access_control=model_info.access_control
                ):
                    filtered_models.append(model)
        models = filtered_models

    return {
        "data": models,
        "object": "list",
    }


class UrlForm(BaseModel):
    url: str


class UploadBlobForm(BaseModel):
    filename: str


def parse_huggingface_url(hf_url):
    try:
        # Parse the URL
        parsed_url = urlparse(hf_url)

        # Get the path and split it into components
        path_components = parsed_url.path.split("/")

        # Extract the desired output
        model_file = path_components[-1]

        return model_file
    except ValueError:
        return None
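

# Example: the trailing path component is taken as the GGUF file name.
#   parse_huggingface_url(
#       "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
#   )  # -> "stablelm-zephyr-3b.Q2_K.gguf"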


async def download_file_stream(
    ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
):
    done = False

    if os.path.exists(file_path):
        current_size = os.path.getsize(file_path)
    else:
        current_size = 0

    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}

    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout

    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.get(
            file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL
        ) as response:
            total_size = int(response.headers.get("content-length", 0)) + current_size

            with open(file_path, "ab+") as file:
                async for data in response.content.iter_chunked(chunk_size):
                    current_size += len(data)
                    file.write(data)

                    done = current_size == total_size
                    progress = round((current_size / total_size) * 100, 2)

                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'

                if done:
                    file.seek(0)
                    chunk_size = 1024 * 1024 * 2
                    hashed = calculate_sha256(file, chunk_size)
                    file.seek(0)

                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
                    response = requests.post(url, data=file)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file_name,
                        }
                        os.remove(file_path)

                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        raise Exception(
                            "Ollama: Could not create blob, please try again."
                        )


# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"


@router.post("/models/download")
@router.post("/models/download/{url_idx}")
async def download_model(
    request: Request,
    form_data: UrlForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    allowed_hosts = ["https://huggingface.co/", "https://github.com/"]

    if not any(form_data.url.startswith(host) for host in allowed_hosts):
        raise HTTPException(
            status_code=400,
            detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
        )

    if url_idx is None:
        url_idx = 0
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

    file_name = parse_huggingface_url(form_data.url)

    if file_name:
        file_path = f"{UPLOAD_DIR}/{file_name}"
        return StreamingResponse(
            download_file_stream(url, form_data.url, file_path, file_name),
        )
    else:
        return None


# TODO: Progress bar does not reflect size & duration of upload.
@router.post("/models/upload")
@router.post("/models/upload/{url_idx}")
async def upload_model(
    request: Request,
    file: UploadFile = File(...),
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        url_idx = 0
    ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

    filename = os.path.basename(file.filename)
    file_path = os.path.join(UPLOAD_DIR, filename)
    os.makedirs(UPLOAD_DIR, exist_ok=True)

    # --- P1: save file locally ---
    chunk_size = 1024 * 1024 * 2  # 2 MB chunks
    with open(file_path, "wb") as out_f:
        while True:
            chunk = file.file.read(chunk_size)
            # log.info(f"Chunk: {str(chunk)}")  # DEBUG
            if not chunk:
                break
            out_f.write(chunk)

    async def file_process_stream():
        nonlocal ollama_url
        total_size = os.path.getsize(file_path)
        log.info(f"Total Model Size: {str(total_size)}")  # DEBUG

        # --- P2: SSE progress + calculate sha256 hash ---
        file_hash = calculate_sha256(file_path, chunk_size)
        log.info(f"Model Hash: {str(file_hash)}")  # DEBUG
        try:
            with open(file_path, "rb") as f:
                bytes_read = 0
                while chunk := f.read(chunk_size):
                    bytes_read += len(chunk)
                    progress = round(bytes_read / total_size * 100, 2)
                    data_msg = {
                        "progress": progress,
                        "total": total_size,
                        "completed": bytes_read,
                    }
                    yield f"data: {json.dumps(data_msg)}\n\n"

            # --- P3: Upload to ollama /api/blobs ---
            with open(file_path, "rb") as f:
                url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
                response = requests.post(url, data=f)

            if response.ok:
                log.info("Uploaded to /api/blobs")  # DEBUG
                # Remove local file
                os.remove(file_path)

                # Create model in ollama
                model_name, ext = os.path.splitext(filename)
                log.info(f"Created Model: {model_name}")  # DEBUG

                create_payload = {
                    "model": model_name,
                    # Reference the file by its original name => the uploaded blob's digest
                    "files": {filename: f"sha256:{file_hash}"},
                }
                log.info(f"Model Payload: {create_payload}")  # DEBUG

                # Call ollama /api/create
                # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
                create_resp = requests.post(
                    url=f"{ollama_url}/api/create",
                    headers={"Content-Type": "application/json"},
                    data=json.dumps(create_payload),
                )

                if create_resp.ok:
                    log.info("API SUCCESS!")  # DEBUG
                    done_msg = {
                        "done": True,
                        "blob": f"sha256:{file_hash}",
                        "name": filename,
                        "model_created": model_name,
                    }
                    yield f"data: {json.dumps(done_msg)}\n\n"
                else:
                    raise Exception(
                        f"Failed to create model in Ollama. {create_resp.text}"
                    )
            else:
                raise Exception("Ollama: Could not create blob, please try again.")
        except Exception as e:
            res = {"error": str(e)}
            yield f"data: {json.dumps(res)}\n\n"

    return StreamingResponse(file_process_stream(), media_type="text/event-stream")