1
0

ollama.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855
# TODO: Implement a more intelligent load-balancing mechanism for distributing requests among multiple backend instances.
# The current implementation picks a backend at random (random.choice), which is not true round-robin. Consider algorithms such as
# weighted round-robin, least connections, or least response time for better resource utilization and performance.
  4. import asyncio
  5. import json
  6. import logging
  7. import os
  8. import random
  9. import re
  10. import time
  11. from datetime import datetime
  12. from typing import Optional, Union
  13. from urllib.parse import urlparse
  14. import aiohttp
  15. from aiocache import cached
  16. import requests
  17. from urllib.parse import quote
  18. from open_webui.models.chats import Chats
  19. from open_webui.models.users import UserModel
  20. from open_webui.env import (
  21. ENABLE_FORWARD_USER_INFO_HEADERS,
  22. )
  23. from fastapi import (
  24. Depends,
  25. FastAPI,
  26. File,
  27. HTTPException,
  28. Request,
  29. UploadFile,
  30. APIRouter,
  31. )
  32. from fastapi.middleware.cors import CORSMiddleware
  33. from fastapi.responses import StreamingResponse
  34. from pydantic import BaseModel, ConfigDict, validator
  35. from starlette.background import BackgroundTask
  36. from open_webui.models.models import Models
  37. from open_webui.utils.misc import (
  38. calculate_sha256,
  39. )
  40. from open_webui.utils.payload import (
  41. apply_model_params_to_body_ollama,
  42. apply_model_params_to_body_openai,
  43. apply_system_prompt_to_body,
  44. )
  45. from open_webui.utils.auth import get_admin_user, get_verified_user
  46. from open_webui.utils.access_control import has_access
  47. from open_webui.config import (
  48. UPLOAD_DIR,
  49. )
  50. from open_webui.env import (
  51. ENV,
  52. SRC_LOG_LEVELS,
  53. MODELS_CACHE_TTL,
  54. AIOHTTP_CLIENT_SESSION_SSL,
  55. AIOHTTP_CLIENT_TIMEOUT,
  56. AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
  57. BYPASS_MODEL_ACCESS_CONTROL,
  58. )
  59. from open_webui.constants import ERROR_MESSAGES
  60. log = logging.getLogger(__name__)
  61. log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
  62. ##########################################
  63. #
  64. # Utility functions
  65. #
  66. ##########################################
  67. async def send_get_request(url, key=None, user: UserModel = None):
  68. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
  69. try:
  70. async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
  71. async with session.get(
  72. url,
  73. headers={
  74. "Content-Type": "application/json",
  75. **({"Authorization": f"Bearer {key}"} if key else {}),
  76. **(
  77. {
  78. "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
  79. "X-OpenWebUI-User-Id": user.id,
  80. "X-OpenWebUI-User-Email": user.email,
  81. "X-OpenWebUI-User-Role": user.role,
  82. }
  83. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  84. else {}
  85. ),
  86. },
  87. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  88. ) as response:
  89. return await response.json()
  90. except Exception as e:
  91. # Handle connection error here
  92. log.error(f"Connection error: {e}")
  93. return None
  94. async def cleanup_response(
  95. response: Optional[aiohttp.ClientResponse],
  96. session: Optional[aiohttp.ClientSession],
  97. ):
  98. if response:
  99. response.close()
  100. if session:
  101. await session.close()
  102. async def send_post_request(
  103. url: str,
  104. payload: Union[str, bytes],
  105. stream: bool = True,
  106. key: Optional[str] = None,
  107. content_type: Optional[str] = None,
  108. user: UserModel = None,
  109. metadata: Optional[dict] = None,
  110. ):
  111. r = None
  112. try:
  113. session = aiohttp.ClientSession(
  114. trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  115. )
  116. r = await session.post(
  117. url,
  118. data=payload,
  119. headers={
  120. "Content-Type": "application/json",
  121. **({"Authorization": f"Bearer {key}"} if key else {}),
  122. **(
  123. {
  124. "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
  125. "X-OpenWebUI-User-Id": user.id,
  126. "X-OpenWebUI-User-Email": user.email,
  127. "X-OpenWebUI-User-Role": user.role,
  128. **(
  129. {"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
  130. if metadata and metadata.get("chat_id")
  131. else {}
  132. ),
  133. }
  134. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  135. else {}
  136. ),
  137. },
  138. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  139. )
  140. if r.ok is False:
  141. try:
  142. res = await r.json()
  143. await cleanup_response(r, session)
  144. if "error" in res:
  145. raise HTTPException(status_code=r.status, detail=res["error"])
  146. except HTTPException as e:
  147. raise e # Re-raise HTTPException to be handled by FastAPI
  148. except Exception as e:
  149. log.error(f"Failed to parse error response: {e}")
  150. raise HTTPException(
  151. status_code=r.status,
  152. detail=f"Open WebUI: Server Connection Error",
  153. )
  154. r.raise_for_status() # Raises an error for bad responses (4xx, 5xx)
  155. if stream:
  156. response_headers = dict(r.headers)
  157. if content_type:
  158. response_headers["Content-Type"] = content_type
  159. return StreamingResponse(
  160. r.content,
  161. status_code=r.status,
  162. headers=response_headers,
  163. background=BackgroundTask(
  164. cleanup_response, response=r, session=session
  165. ),
  166. )
  167. else:
  168. res = await r.json()
  169. return res
  170. except HTTPException as e:
  171. raise e # Re-raise HTTPException to be handled by FastAPI
  172. except Exception as e:
  173. detail = f"Ollama: {e}"
  174. raise HTTPException(
  175. status_code=r.status if r else 500,
  176. detail=detail if e else "Open WebUI: Server Connection Error",
  177. )
  178. finally:
  179. if not stream:
  180. await cleanup_response(r, session)
  181. def get_api_key(idx, url, configs):
  182. parsed_url = urlparse(url)
  183. base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
  184. return configs.get(str(idx), configs.get(base_url, {})).get(
  185. "key", None
  186. ) # Legacy support
  187. ##########################################
  188. #
  189. # API routes
  190. #
  191. ##########################################
  192. router = APIRouter()
  193. @router.head("/")
  194. @router.get("/")
  195. async def get_status():
  196. return {"status": True}
class ConnectionVerificationForm(BaseModel):
    """Request body for the ``/verify`` endpoint."""

    # Base URL of the server to probe; "/api/version" is appended by the handler.
    url: str
    # Optional bearer token sent in the Authorization header.
    key: Optional[str] = None
  200. @router.post("/verify")
  201. async def verify_connection(
  202. form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
  203. ):
  204. url = form_data.url
  205. key = form_data.key
  206. async with aiohttp.ClientSession(
  207. trust_env=True,
  208. timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
  209. ) as session:
  210. try:
  211. async with session.get(
  212. f"{url}/api/version",
  213. headers={
  214. **({"Authorization": f"Bearer {key}"} if key else {}),
  215. **(
  216. {
  217. "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
  218. "X-OpenWebUI-User-Id": user.id,
  219. "X-OpenWebUI-User-Email": user.email,
  220. "X-OpenWebUI-User-Role": user.role,
  221. }
  222. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  223. else {}
  224. ),
  225. },
  226. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  227. ) as r:
  228. if r.status != 200:
  229. detail = f"HTTP Error: {r.status}"
  230. res = await r.json()
  231. if "error" in res:
  232. detail = f"External Error: {res['error']}"
  233. raise Exception(detail)
  234. data = await r.json()
  235. return data
  236. except aiohttp.ClientError as e:
  237. log.exception(f"Client error: {str(e)}")
  238. raise HTTPException(
  239. status_code=500, detail="Open WebUI: Server Connection Error"
  240. )
  241. except Exception as e:
  242. log.exception(f"Unexpected error: {e}")
  243. error_detail = f"Unexpected error: {str(e)}"
  244. raise HTTPException(status_code=500, detail=error_detail)
  245. @router.get("/config")
  246. async def get_config(request: Request, user=Depends(get_admin_user)):
  247. return {
  248. "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
  249. "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
  250. "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
  251. }
class OllamaConfigForm(BaseModel):
    """Request body for ``/config/update``."""

    # Optional flag; the handler stores it as given (including None).
    ENABLE_OLLAMA_API: Optional[bool] = None
    # Backend base URLs; list position is the index used elsewhere as url_idx.
    OLLAMA_BASE_URLS: list[str]
    # Per-backend settings; the handler keeps only entries keyed by a stringified URL index.
    OLLAMA_API_CONFIGS: dict
  256. @router.post("/config/update")
  257. async def update_config(
  258. request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
  259. ):
  260. request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
  261. request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
  262. request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
  263. # Remove the API configs that are not in the API URLS
  264. keys = list(map(str, range(len(request.app.state.config.OLLAMA_BASE_URLS))))
  265. request.app.state.config.OLLAMA_API_CONFIGS = {
  266. key: value
  267. for key, value in request.app.state.config.OLLAMA_API_CONFIGS.items()
  268. if key in keys
  269. }
  270. return {
  271. "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
  272. "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
  273. "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
  274. }
  275. def merge_ollama_models_lists(model_lists):
  276. merged_models = {}
  277. for idx, model_list in enumerate(model_lists):
  278. if model_list is not None:
  279. for model in model_list:
  280. id = model.get("model")
  281. if id is not None:
  282. if id not in merged_models:
  283. model["urls"] = [idx]
  284. merged_models[id] = model
  285. else:
  286. merged_models[id]["urls"].append(idx)
  287. return list(merged_models.values())
# Results are cached for MODELS_CACHE_TTL; the cache key is per user id when a
# user is given, otherwise a single shared key.
@cached(
    ttl=MODELS_CACHE_TTL,
    key=lambda _, user: f"ollama_all_models_{user.id}" if user else "ollama_all_models",
)
async def get_all_models(request: Request, user: UserModel = None):
    # Query "/api/tags" on every configured base URL, apply the per-connection
    # settings (model_ids filter, prefix_id, tags, connection_type), merge the
    # lists, and store the result on request.app.state.OLLAMA_MODELS.
    # Returns {"models": [...]}; the list is empty when the Ollama API is disabled.
    log.info("get_all_models()")
    if request.app.state.config.ENABLE_OLLAMA_API:
        request_tasks = []
        for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
            if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
                url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
            ):
                # No config entry for this connection: request without a key.
                request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
            else:
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(
                        url, {}
                    ),  # Legacy support
                )

                enable = api_config.get("enable", True)
                key = api_config.get("key", None)

                if enable:
                    request_tasks.append(
                        send_get_request(f"{url}/api/tags", key, user=user)
                    )
                else:
                    # Placeholder that resolves to None, so positions in
                    # `responses` keep matching positions in OLLAMA_BASE_URLS.
                    request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))

        responses = await asyncio.gather(*request_tasks)

        for idx, response in enumerate(responses):
            if response:
                url = request.app.state.config.OLLAMA_BASE_URLS[idx]
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(
                        url, {}
                    ),  # Legacy support
                )

                connection_type = api_config.get("connection_type", "local")

                prefix_id = api_config.get("prefix_id", None)
                tags = api_config.get("tags", [])
                model_ids = api_config.get("model_ids", [])

                # A non-empty model_ids list acts as an allow-list for this connection.
                if len(model_ids) != 0 and "models" in response:
                    response["models"] = list(
                        filter(
                            lambda model: model["model"] in model_ids,
                            response["models"],
                        )
                    )

                # The response dicts are modified in place before merging.
                for model in response.get("models", []):
                    if prefix_id:
                        model["model"] = f"{prefix_id}.{model['model']}"

                    if tags:
                        model["tags"] = tags

                    if connection_type:
                        model["connection_type"] = connection_type

        models = {
            "models": merge_ollama_models_lists(
                map(
                    lambda response: response.get("models", []) if response else None,
                    responses,
                )
            )
        }

        # Best effort: annotate models with "expires_at" taken from the
        # loaded-models listing; any failure is only logged at debug level.
        try:
            loaded_models = await get_ollama_loaded_models(request, user=user)
            expires_map = {
                m["model"]: m["expires_at"]
                for m in loaded_models["models"]
                if "expires_at" in m
            }

            for m in models["models"]:
                if m["model"] in expires_map:
                    # Parse ISO8601 datetime with offset, get unix timestamp as int
                    dt = datetime.fromisoformat(expires_map[m["model"]])
                    m["expires_at"] = int(dt.timestamp())
        except Exception as e:
            log.debug(f"Failed to get loaded models: {e}")

    else:
        models = {"models": []}

    # Index by model id so other handlers can look up a model's "urls" entry.
    request.app.state.OLLAMA_MODELS = {
        model["model"]: model for model in models["models"]
    }
    return models
  372. async def get_filtered_models(models, user):
  373. # Filter models based on user access control
  374. filtered_models = []
  375. for model in models.get("models", []):
  376. model_info = Models.get_model_by_id(model["model"])
  377. if model_info:
  378. if user.id == model_info.user_id or has_access(
  379. user.id, type="read", access_control=model_info.access_control
  380. ):
  381. filtered_models.append(model)
  382. return filtered_models
  383. @router.get("/api/tags")
  384. @router.get("/api/tags/{url_idx}")
  385. async def get_ollama_tags(
  386. request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
  387. ):
  388. models = []
  389. if url_idx is None:
  390. models = await get_all_models(request, user=user)
  391. else:
  392. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  393. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  394. r = None
  395. try:
  396. r = requests.request(
  397. method="GET",
  398. url=f"{url}/api/tags",
  399. headers={
  400. **({"Authorization": f"Bearer {key}"} if key else {}),
  401. **(
  402. {
  403. "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
  404. "X-OpenWebUI-User-Id": user.id,
  405. "X-OpenWebUI-User-Email": user.email,
  406. "X-OpenWebUI-User-Role": user.role,
  407. }
  408. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  409. else {}
  410. ),
  411. },
  412. )
  413. r.raise_for_status()
  414. models = r.json()
  415. except Exception as e:
  416. log.exception(e)
  417. detail = None
  418. if r is not None:
  419. try:
  420. res = r.json()
  421. if "error" in res:
  422. detail = f"Ollama: {res['error']}"
  423. except Exception:
  424. detail = f"Ollama: {e}"
  425. raise HTTPException(
  426. status_code=r.status_code if r else 500,
  427. detail=detail if detail else "Open WebUI: Server Connection Error",
  428. )
  429. if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
  430. models["models"] = await get_filtered_models(models, user)
  431. return models
  432. @router.get("/api/ps")
  433. async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)):
  434. """
  435. List models that are currently loaded into Ollama memory, and which node they are loaded on.
  436. """
  437. if request.app.state.config.ENABLE_OLLAMA_API:
  438. request_tasks = []
  439. for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
  440. if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
  441. url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
  442. ):
  443. request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
  444. else:
  445. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  446. str(idx),
  447. request.app.state.config.OLLAMA_API_CONFIGS.get(
  448. url, {}
  449. ), # Legacy support
  450. )
  451. enable = api_config.get("enable", True)
  452. key = api_config.get("key", None)
  453. if enable:
  454. request_tasks.append(
  455. send_get_request(f"{url}/api/ps", key, user=user)
  456. )
  457. else:
  458. request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
  459. responses = await asyncio.gather(*request_tasks)
  460. for idx, response in enumerate(responses):
  461. if response:
  462. url = request.app.state.config.OLLAMA_BASE_URLS[idx]
  463. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  464. str(idx),
  465. request.app.state.config.OLLAMA_API_CONFIGS.get(
  466. url, {}
  467. ), # Legacy support
  468. )
  469. prefix_id = api_config.get("prefix_id", None)
  470. for model in response.get("models", []):
  471. if prefix_id:
  472. model["model"] = f"{prefix_id}.{model['model']}"
  473. models = {
  474. "models": merge_ollama_models_lists(
  475. map(
  476. lambda response: response.get("models", []) if response else None,
  477. responses,
  478. )
  479. )
  480. }
  481. else:
  482. models = {"models": []}
  483. return models
  484. @router.get("/api/version")
  485. @router.get("/api/version/{url_idx}")
  486. async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
  487. if request.app.state.config.ENABLE_OLLAMA_API:
  488. if url_idx is None:
  489. # returns lowest version
  490. request_tasks = []
  491. for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
  492. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  493. str(idx),
  494. request.app.state.config.OLLAMA_API_CONFIGS.get(
  495. url, {}
  496. ), # Legacy support
  497. )
  498. enable = api_config.get("enable", True)
  499. key = api_config.get("key", None)
  500. if enable:
  501. request_tasks.append(
  502. send_get_request(
  503. f"{url}/api/version",
  504. key,
  505. )
  506. )
  507. responses = await asyncio.gather(*request_tasks)
  508. responses = list(filter(lambda x: x is not None, responses))
  509. if len(responses) > 0:
  510. lowest_version = min(
  511. responses,
  512. key=lambda x: tuple(
  513. map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
  514. ),
  515. )
  516. return {"version": lowest_version["version"]}
  517. else:
  518. raise HTTPException(
  519. status_code=500,
  520. detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
  521. )
  522. else:
  523. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  524. r = None
  525. try:
  526. r = requests.request(method="GET", url=f"{url}/api/version")
  527. r.raise_for_status()
  528. return r.json()
  529. except Exception as e:
  530. log.exception(e)
  531. detail = None
  532. if r is not None:
  533. try:
  534. res = r.json()
  535. if "error" in res:
  536. detail = f"Ollama: {res['error']}"
  537. except Exception:
  538. detail = f"Ollama: {e}"
  539. raise HTTPException(
  540. status_code=r.status_code if r else 500,
  541. detail=detail if detail else "Open WebUI: Server Connection Error",
  542. )
  543. else:
  544. return {"version": False}
class ModelNameForm(BaseModel):
    """Request body that names a model; extra fields are accepted."""

    # Model identifier; handlers in this file fall back to an extra "name" field when this is absent.
    model: Optional[str] = None

    # extra="allow" keeps undeclared fields on the model so they survive model_dump().
    model_config = ConfigDict(
        extra="allow",
    )
  550. @router.post("/api/unload")
  551. async def unload_model(
  552. request: Request,
  553. form_data: ModelNameForm,
  554. user=Depends(get_admin_user),
  555. ):
  556. form_data = form_data.model_dump(exclude_none=True)
  557. model_name = form_data.get("model", form_data.get("name"))
  558. if not model_name:
  559. raise HTTPException(
  560. status_code=400, detail="Missing name of the model to unload."
  561. )
  562. # Refresh/load models if needed, get mapping from name to URLs
  563. await get_all_models(request, user=user)
  564. models = request.app.state.OLLAMA_MODELS
  565. # Canonicalize model name (if not supplied with version)
  566. if ":" not in model_name:
  567. model_name = f"{model_name}:latest"
  568. if model_name not in models:
  569. raise HTTPException(
  570. status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name)
  571. )
  572. url_indices = models[model_name]["urls"]
  573. # Send unload to ALL url_indices
  574. results = []
  575. errors = []
  576. for idx in url_indices:
  577. url = request.app.state.config.OLLAMA_BASE_URLS[idx]
  578. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  579. str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
  580. )
  581. key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  582. prefix_id = api_config.get("prefix_id", None)
  583. if prefix_id and model_name.startswith(f"{prefix_id}."):
  584. model_name = model_name[len(f"{prefix_id}.") :]
  585. payload = {"model": model_name, "keep_alive": 0, "prompt": ""}
  586. try:
  587. res = await send_post_request(
  588. url=f"{url}/api/generate",
  589. payload=json.dumps(payload),
  590. stream=False,
  591. key=key,
  592. user=user,
  593. )
  594. results.append({"url_idx": idx, "success": True, "response": res})
  595. except Exception as e:
  596. log.exception(f"Failed to unload model on node {idx}: {e}")
  597. errors.append({"url_idx": idx, "success": False, "error": str(e)})
  598. if len(errors) > 0:
  599. raise HTTPException(
  600. status_code=500,
  601. detail=f"Failed to unload model on {len(errors)} nodes: {errors}",
  602. )
  603. return {"status": True}
  604. @router.post("/api/pull")
  605. @router.post("/api/pull/{url_idx}")
  606. async def pull_model(
  607. request: Request,
  608. form_data: ModelNameForm,
  609. url_idx: int = 0,
  610. user=Depends(get_admin_user),
  611. ):
  612. form_data = form_data.model_dump(exclude_none=True)
  613. form_data["model"] = form_data.get("model", form_data.get("name"))
  614. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  615. log.info(f"url: {url}")
  616. # Admin should be able to pull models from any source
  617. payload = {**form_data, "insecure": True}
  618. return await send_post_request(
  619. url=f"{url}/api/pull",
  620. payload=json.dumps(payload),
  621. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  622. user=user,
  623. )
class PushModelForm(BaseModel):
    """Request body for the push endpoint."""

    # Name of the model to push; also used to locate the backend when no url_idx is given.
    model: str
    # Optional flags; omitted from the forwarded JSON when None (the handler dumps with exclude_none).
    insecure: Optional[bool] = None
    stream: Optional[bool] = None
  628. @router.delete("/api/push")
  629. @router.delete("/api/push/{url_idx}")
  630. async def push_model(
  631. request: Request,
  632. form_data: PushModelForm,
  633. url_idx: Optional[int] = None,
  634. user=Depends(get_admin_user),
  635. ):
  636. if url_idx is None:
  637. await get_all_models(request, user=user)
  638. models = request.app.state.OLLAMA_MODELS
  639. if form_data.model in models:
  640. url_idx = models[form_data.model]["urls"][0]
  641. else:
  642. raise HTTPException(
  643. status_code=400,
  644. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  645. )
  646. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  647. log.debug(f"url: {url}")
  648. return await send_post_request(
  649. url=f"{url}/api/push",
  650. payload=form_data.model_dump_json(exclude_none=True).encode(),
  651. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  652. user=user,
  653. )
class CreateModelForm(BaseModel):
    """Request body for the create endpoint; extra fields are accepted."""

    # All declared fields are optional and dropped from the forwarded JSON when None.
    model: Optional[str] = None
    stream: Optional[bool] = None
    path: Optional[str] = None

    # extra="allow" lets undeclared fields pass through to the forwarded body.
    model_config = ConfigDict(extra="allow")
  659. @router.post("/api/create")
  660. @router.post("/api/create/{url_idx}")
  661. async def create_model(
  662. request: Request,
  663. form_data: CreateModelForm,
  664. url_idx: int = 0,
  665. user=Depends(get_admin_user),
  666. ):
  667. log.debug(f"form_data: {form_data}")
  668. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  669. return await send_post_request(
  670. url=f"{url}/api/create",
  671. payload=form_data.model_dump_json(exclude_none=True).encode(),
  672. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  673. user=user,
  674. )
class CopyModelForm(BaseModel):
    """Request body for the copy endpoint."""

    # Existing model name; also used to locate the backend when no url_idx is given.
    source: str
    # Name for the copy; forwarded to the backend unchanged.
    destination: str
  678. @router.post("/api/copy")
  679. @router.post("/api/copy/{url_idx}")
  680. async def copy_model(
  681. request: Request,
  682. form_data: CopyModelForm,
  683. url_idx: Optional[int] = None,
  684. user=Depends(get_admin_user),
  685. ):
  686. if url_idx is None:
  687. await get_all_models(request, user=user)
  688. models = request.app.state.OLLAMA_MODELS
  689. if form_data.source in models:
  690. url_idx = models[form_data.source]["urls"][0]
  691. else:
  692. raise HTTPException(
  693. status_code=400,
  694. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
  695. )
  696. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  697. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  698. try:
  699. r = requests.request(
  700. method="POST",
  701. url=f"{url}/api/copy",
  702. headers={
  703. "Content-Type": "application/json",
  704. **({"Authorization": f"Bearer {key}"} if key else {}),
  705. **(
  706. {
  707. "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
  708. "X-OpenWebUI-User-Id": user.id,
  709. "X-OpenWebUI-User-Email": user.email,
  710. "X-OpenWebUI-User-Role": user.role,
  711. }
  712. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  713. else {}
  714. ),
  715. },
  716. data=form_data.model_dump_json(exclude_none=True).encode(),
  717. )
  718. r.raise_for_status()
  719. log.debug(f"r.text: {r.text}")
  720. return True
  721. except Exception as e:
  722. log.exception(e)
  723. detail = None
  724. if r is not None:
  725. try:
  726. res = r.json()
  727. if "error" in res:
  728. detail = f"Ollama: {res['error']}"
  729. except Exception:
  730. detail = f"Ollama: {e}"
  731. raise HTTPException(
  732. status_code=r.status_code if r else 500,
  733. detail=detail if detail else "Open WebUI: Server Connection Error",
  734. )
  735. @router.delete("/api/delete")
  736. @router.delete("/api/delete/{url_idx}")
  737. async def delete_model(
  738. request: Request,
  739. form_data: ModelNameForm,
  740. url_idx: Optional[int] = None,
  741. user=Depends(get_admin_user),
  742. ):
  743. form_data = form_data.model_dump(exclude_none=True)
  744. form_data["model"] = form_data.get("model", form_data.get("name"))
  745. model = form_data.get("model")
  746. if url_idx is None:
  747. await get_all_models(request, user=user)
  748. models = request.app.state.OLLAMA_MODELS
  749. if model in models:
  750. url_idx = models[model]["urls"][0]
  751. else:
  752. raise HTTPException(
  753. status_code=400,
  754. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
  755. )
  756. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  757. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  758. try:
  759. r = requests.request(
  760. method="DELETE",
  761. url=f"{url}/api/delete",
  762. data=json.dumps(form_data).encode(),
  763. headers={
  764. "Content-Type": "application/json",
  765. **({"Authorization": f"Bearer {key}"} if key else {}),
  766. **(
  767. {
  768. "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
  769. "X-OpenWebUI-User-Id": user.id,
  770. "X-OpenWebUI-User-Email": user.email,
  771. "X-OpenWebUI-User-Role": user.role,
  772. }
  773. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  774. else {}
  775. ),
  776. },
  777. )
  778. r.raise_for_status()
  779. log.debug(f"r.text: {r.text}")
  780. return True
  781. except Exception as e:
  782. log.exception(e)
  783. detail = None
  784. if r is not None:
  785. try:
  786. res = r.json()
  787. if "error" in res:
  788. detail = f"Ollama: {res['error']}"
  789. except Exception:
  790. detail = f"Ollama: {e}"
  791. raise HTTPException(
  792. status_code=r.status_code if r else 500,
  793. detail=detail if detail else "Open WebUI: Server Connection Error",
  794. )
  795. @router.post("/api/show")
  796. async def show_model_info(
  797. request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
  798. ):
  799. form_data = form_data.model_dump(exclude_none=True)
  800. form_data["model"] = form_data.get("model", form_data.get("name"))
  801. await get_all_models(request, user=user)
  802. models = request.app.state.OLLAMA_MODELS
  803. model = form_data.get("model")
  804. if model not in models:
  805. raise HTTPException(
  806. status_code=400,
  807. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
  808. )
  809. url_idx = random.choice(models[model]["urls"])
  810. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  811. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  812. try:
  813. r = requests.request(
  814. method="POST",
  815. url=f"{url}/api/show",
  816. headers={
  817. "Content-Type": "application/json",
  818. **({"Authorization": f"Bearer {key}"} if key else {}),
  819. **(
  820. {
  821. "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
  822. "X-OpenWebUI-User-Id": user.id,
  823. "X-OpenWebUI-User-Email": user.email,
  824. "X-OpenWebUI-User-Role": user.role,
  825. }
  826. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  827. else {}
  828. ),
  829. },
  830. data=json.dumps(form_data).encode(),
  831. )
  832. r.raise_for_status()
  833. return r.json()
  834. except Exception as e:
  835. log.exception(e)
  836. detail = None
  837. if r is not None:
  838. try:
  839. res = r.json()
  840. if "error" in res:
  841. detail = f"Ollama: {res['error']}"
  842. except Exception:
  843. detail = f"Ollama: {e}"
  844. raise HTTPException(
  845. status_code=r.status_code if r else 500,
  846. detail=detail if detail else "Open WebUI: Server Connection Error",
  847. )
class GenerateEmbedForm(BaseModel):
    """Request body for the ``embed`` handler; serialized (without ``None`` fields) and forwarded."""

    model: str
    # Either a single string or a list of strings.
    input: list[str] | str
    truncate: Optional[bool] = None
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None

    # Keys not declared above are accepted instead of being rejected.
    model_config = ConfigDict(
        extra="allow",
    )
  857. @router.post("/api/embed")
  858. @router.post("/api/embed/{url_idx}")
  859. async def embed(
  860. request: Request,
  861. form_data: GenerateEmbedForm,
  862. url_idx: Optional[int] = None,
  863. user=Depends(get_verified_user),
  864. ):
  865. log.info(f"generate_ollama_batch_embeddings {form_data}")
  866. if url_idx is None:
  867. await get_all_models(request, user=user)
  868. models = request.app.state.OLLAMA_MODELS
  869. model = form_data.model
  870. if ":" not in model:
  871. model = f"{model}:latest"
  872. if model in models:
  873. url_idx = random.choice(models[model]["urls"])
  874. else:
  875. raise HTTPException(
  876. status_code=400,
  877. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  878. )
  879. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  880. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  881. str(url_idx),
  882. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  883. )
  884. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  885. prefix_id = api_config.get("prefix_id", None)
  886. if prefix_id:
  887. form_data.model = form_data.model.replace(f"{prefix_id}.", "")
  888. try:
  889. r = requests.request(
  890. method="POST",
  891. url=f"{url}/api/embed",
  892. headers={
  893. "Content-Type": "application/json",
  894. **({"Authorization": f"Bearer {key}"} if key else {}),
  895. **(
  896. {
  897. "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
  898. "X-OpenWebUI-User-Id": user.id,
  899. "X-OpenWebUI-User-Email": user.email,
  900. "X-OpenWebUI-User-Role": user.role,
  901. }
  902. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  903. else {}
  904. ),
  905. },
  906. data=form_data.model_dump_json(exclude_none=True).encode(),
  907. )
  908. r.raise_for_status()
  909. data = r.json()
  910. return data
  911. except Exception as e:
  912. log.exception(e)
  913. detail = None
  914. if r is not None:
  915. try:
  916. res = r.json()
  917. if "error" in res:
  918. detail = f"Ollama: {res['error']}"
  919. except Exception:
  920. detail = f"Ollama: {e}"
  921. raise HTTPException(
  922. status_code=r.status_code if r else 500,
  923. detail=detail if detail else "Open WebUI: Server Connection Error",
  924. )
class GenerateEmbeddingsForm(BaseModel):
    """Request body for the ``embeddings`` handler; serialized (without ``None`` fields) and forwarded."""

    model: str
    prompt: str
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None
  930. @router.post("/api/embeddings")
  931. @router.post("/api/embeddings/{url_idx}")
  932. async def embeddings(
  933. request: Request,
  934. form_data: GenerateEmbeddingsForm,
  935. url_idx: Optional[int] = None,
  936. user=Depends(get_verified_user),
  937. ):
  938. log.info(f"generate_ollama_embeddings {form_data}")
  939. if url_idx is None:
  940. await get_all_models(request, user=user)
  941. models = request.app.state.OLLAMA_MODELS
  942. model = form_data.model
  943. if ":" not in model:
  944. model = f"{model}:latest"
  945. if model in models:
  946. url_idx = random.choice(models[model]["urls"])
  947. else:
  948. raise HTTPException(
  949. status_code=400,
  950. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  951. )
  952. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  953. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  954. str(url_idx),
  955. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  956. )
  957. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  958. prefix_id = api_config.get("prefix_id", None)
  959. if prefix_id:
  960. form_data.model = form_data.model.replace(f"{prefix_id}.", "")
  961. try:
  962. r = requests.request(
  963. method="POST",
  964. url=f"{url}/api/embeddings",
  965. headers={
  966. "Content-Type": "application/json",
  967. **({"Authorization": f"Bearer {key}"} if key else {}),
  968. **(
  969. {
  970. "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
  971. "X-OpenWebUI-User-Id": user.id,
  972. "X-OpenWebUI-User-Email": user.email,
  973. "X-OpenWebUI-User-Role": user.role,
  974. }
  975. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  976. else {}
  977. ),
  978. },
  979. data=form_data.model_dump_json(exclude_none=True).encode(),
  980. )
  981. r.raise_for_status()
  982. data = r.json()
  983. return data
  984. except Exception as e:
  985. log.exception(e)
  986. detail = None
  987. if r is not None:
  988. try:
  989. res = r.json()
  990. if "error" in res:
  991. detail = f"Ollama: {res['error']}"
  992. except Exception:
  993. detail = f"Ollama: {e}"
  994. raise HTTPException(
  995. status_code=r.status_code if r else 500,
  996. detail=detail if detail else "Open WebUI: Server Connection Error",
  997. )
class GenerateCompletionForm(BaseModel):
    """Request body for the ``generate_completion`` handler.

    The handler serializes it with ``exclude_none=True``, so optional fields
    left at ``None`` are not forwarded.
    """

    model: str
    prompt: str
    suffix: Optional[str] = None
    images: Optional[list[str]] = None
    format: Optional[Union[dict, str]] = None
    options: Optional[dict] = None
    system: Optional[str] = None
    template: Optional[str] = None
    context: Optional[list[int]] = None
    # Defaults to True, so it is always part of the forwarded payload unless set to None.
    stream: Optional[bool] = True
    raw: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None
  1011. @router.post("/api/generate")
  1012. @router.post("/api/generate/{url_idx}")
  1013. async def generate_completion(
  1014. request: Request,
  1015. form_data: GenerateCompletionForm,
  1016. url_idx: Optional[int] = None,
  1017. user=Depends(get_verified_user),
  1018. ):
  1019. if url_idx is None:
  1020. await get_all_models(request, user=user)
  1021. models = request.app.state.OLLAMA_MODELS
  1022. model = form_data.model
  1023. if ":" not in model:
  1024. model = f"{model}:latest"
  1025. if model in models:
  1026. url_idx = random.choice(models[model]["urls"])
  1027. else:
  1028. raise HTTPException(
  1029. status_code=400,
  1030. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  1031. )
  1032. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1033. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  1034. str(url_idx),
  1035. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  1036. )
  1037. prefix_id = api_config.get("prefix_id", None)
  1038. if prefix_id:
  1039. form_data.model = form_data.model.replace(f"{prefix_id}.", "")
  1040. return await send_post_request(
  1041. url=f"{url}/api/generate",
  1042. payload=form_data.model_dump_json(exclude_none=True).encode(),
  1043. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1044. user=user,
  1045. )
class ChatMessage(BaseModel):
    """One entry of ``GenerateChatCompletionForm.messages``."""

    role: str
    content: Optional[str] = None
    tool_calls: Optional[list[dict]] = None
    images: Optional[list[str]] = None

    @validator("content", pre=True)
    @classmethod
    def check_at_least_one_field(cls, field_value, values, **kwargs):
        """Reject a ``None`` content when ``values`` holds no ``tool_calls``.

        NOTE(review): this is a field validator on ``content``, which is
        declared before ``tool_calls``. If ``values`` only carries fields
        validated earlier, ``tool_calls`` is presumably never present here and
        an explicit ``content=None`` is always rejected — confirm against the
        pydantic version in use.

        Raises:
            ValueError: when ``field_value`` is ``None`` and ``values`` has no
                non-``None`` ``tool_calls`` entry.
        """
        # Raise an error if both 'content' and 'tool_calls' are None
        if field_value is None and (
            "tool_calls" not in values or values["tool_calls"] is None
        ):
            raise ValueError(
                "At least one of 'content' or 'tool_calls' must be provided"
            )

        return field_value
class GenerateChatCompletionForm(BaseModel):
    """Schema that ``generate_chat_completion`` validates its raw dict body against."""

    model: str
    messages: list[ChatMessage]
    format: Optional[Union[dict, str]] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    # Read by the handler to decide the ``stream`` argument of send_post_request.
    stream: Optional[bool] = True
    keep_alive: Optional[Union[int, str]] = None
    tools: Optional[list[dict]] = None

    # Keys not declared above are accepted instead of being rejected.
    model_config = ConfigDict(
        extra="allow",
    )
  1074. async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
  1075. if url_idx is None:
  1076. models = request.app.state.OLLAMA_MODELS
  1077. if model not in models:
  1078. raise HTTPException(
  1079. status_code=400,
  1080. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
  1081. )
  1082. url_idx = random.choice(models[model].get("urls", []))
  1083. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1084. return url, url_idx
  1085. @router.post("/api/chat")
  1086. @router.post("/api/chat/{url_idx}")
  1087. async def generate_chat_completion(
  1088. request: Request,
  1089. form_data: dict,
  1090. url_idx: Optional[int] = None,
  1091. user=Depends(get_verified_user),
  1092. bypass_filter: Optional[bool] = False,
  1093. ):
  1094. if BYPASS_MODEL_ACCESS_CONTROL:
  1095. bypass_filter = True
  1096. metadata = form_data.pop("metadata", None)
  1097. try:
  1098. form_data = GenerateChatCompletionForm(**form_data)
  1099. except Exception as e:
  1100. log.exception(e)
  1101. raise HTTPException(
  1102. status_code=400,
  1103. detail=str(e),
  1104. )
  1105. if isinstance(form_data, BaseModel):
  1106. payload = {**form_data.model_dump(exclude_none=True)}
  1107. if "metadata" in payload:
  1108. del payload["metadata"]
  1109. model_id = payload["model"]
  1110. model_info = Models.get_model_by_id(model_id)
  1111. if model_info:
  1112. if model_info.base_model_id:
  1113. payload["model"] = model_info.base_model_id
  1114. params = model_info.params.model_dump()
  1115. if params:
  1116. system = params.pop("system", None)
  1117. payload = apply_model_params_to_body_ollama(params, payload)
  1118. payload = apply_system_prompt_to_body(system, payload, metadata, user)
  1119. # Check if user has access to the model
  1120. if not bypass_filter and user.role == "user":
  1121. if not (
  1122. user.id == model_info.user_id
  1123. or has_access(
  1124. user.id, type="read", access_control=model_info.access_control
  1125. )
  1126. ):
  1127. raise HTTPException(
  1128. status_code=403,
  1129. detail="Model not found",
  1130. )
  1131. elif not bypass_filter:
  1132. if user.role != "admin":
  1133. raise HTTPException(
  1134. status_code=403,
  1135. detail="Model not found",
  1136. )
  1137. if ":" not in payload["model"]:
  1138. payload["model"] = f"{payload['model']}:latest"
  1139. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  1140. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  1141. str(url_idx),
  1142. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  1143. )
  1144. prefix_id = api_config.get("prefix_id", None)
  1145. if prefix_id:
  1146. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  1147. return await send_post_request(
  1148. url=f"{url}/api/chat",
  1149. payload=json.dumps(payload),
  1150. stream=form_data.stream,
  1151. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1152. content_type="application/x-ndjson",
  1153. user=user,
  1154. metadata=metadata,
  1155. )
  1156. # TODO: we should update this part once Ollama supports other types
class OpenAIChatMessageContent(BaseModel):
    """One element of a list-valued ``OpenAIChatMessage.content``."""

    # Only ``type`` is required; any other keys are accepted via extra="allow".
    type: str
    model_config = ConfigDict(extra="allow")
class OpenAIChatMessage(BaseModel):
    """One entry of ``OpenAIChatCompletionForm.messages``."""

    role: str
    # Either a string, None, or a list of typed content parts.
    content: Union[Optional[str], list[OpenAIChatMessageContent]]

    # Keys not declared above are accepted instead of being rejected.
    model_config = ConfigDict(extra="allow")
class OpenAIChatCompletionForm(BaseModel):
    """Schema that ``generate_openai_chat_completion`` validates its raw dict body against."""

    model: str
    messages: list[OpenAIChatMessage]

    # Keys not declared above (the handler reads e.g. "stream") are accepted.
    model_config = ConfigDict(extra="allow")
class OpenAICompletionForm(BaseModel):
    """Schema that ``generate_openai_completion`` validates its raw dict body against."""

    model: str
    prompt: str

    # Keys not declared above (the handler reads e.g. "stream") are accepted.
    model_config = ConfigDict(extra="allow")
  1172. @router.post("/v1/completions")
  1173. @router.post("/v1/completions/{url_idx}")
  1174. async def generate_openai_completion(
  1175. request: Request,
  1176. form_data: dict,
  1177. url_idx: Optional[int] = None,
  1178. user=Depends(get_verified_user),
  1179. ):
  1180. metadata = form_data.pop("metadata", None)
  1181. try:
  1182. form_data = OpenAICompletionForm(**form_data)
  1183. except Exception as e:
  1184. log.exception(e)
  1185. raise HTTPException(
  1186. status_code=400,
  1187. detail=str(e),
  1188. )
  1189. payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
  1190. if "metadata" in payload:
  1191. del payload["metadata"]
  1192. model_id = form_data.model
  1193. if ":" not in model_id:
  1194. model_id = f"{model_id}:latest"
  1195. model_info = Models.get_model_by_id(model_id)
  1196. if model_info:
  1197. if model_info.base_model_id:
  1198. payload["model"] = model_info.base_model_id
  1199. params = model_info.params.model_dump()
  1200. if params:
  1201. payload = apply_model_params_to_body_openai(params, payload)
  1202. # Check if user has access to the model
  1203. if user.role == "user":
  1204. if not (
  1205. user.id == model_info.user_id
  1206. or has_access(
  1207. user.id, type="read", access_control=model_info.access_control
  1208. )
  1209. ):
  1210. raise HTTPException(
  1211. status_code=403,
  1212. detail="Model not found",
  1213. )
  1214. else:
  1215. if user.role != "admin":
  1216. raise HTTPException(
  1217. status_code=403,
  1218. detail="Model not found",
  1219. )
  1220. if ":" not in payload["model"]:
  1221. payload["model"] = f"{payload['model']}:latest"
  1222. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  1223. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  1224. str(url_idx),
  1225. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  1226. )
  1227. prefix_id = api_config.get("prefix_id", None)
  1228. if prefix_id:
  1229. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  1230. return await send_post_request(
  1231. url=f"{url}/v1/completions",
  1232. payload=json.dumps(payload),
  1233. stream=payload.get("stream", False),
  1234. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1235. user=user,
  1236. metadata=metadata,
  1237. )
  1238. @router.post("/v1/chat/completions")
  1239. @router.post("/v1/chat/completions/{url_idx}")
  1240. async def generate_openai_chat_completion(
  1241. request: Request,
  1242. form_data: dict,
  1243. url_idx: Optional[int] = None,
  1244. user=Depends(get_verified_user),
  1245. ):
  1246. metadata = form_data.pop("metadata", None)
  1247. try:
  1248. completion_form = OpenAIChatCompletionForm(**form_data)
  1249. except Exception as e:
  1250. log.exception(e)
  1251. raise HTTPException(
  1252. status_code=400,
  1253. detail=str(e),
  1254. )
  1255. payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
  1256. if "metadata" in payload:
  1257. del payload["metadata"]
  1258. model_id = completion_form.model
  1259. if ":" not in model_id:
  1260. model_id = f"{model_id}:latest"
  1261. model_info = Models.get_model_by_id(model_id)
  1262. if model_info:
  1263. if model_info.base_model_id:
  1264. payload["model"] = model_info.base_model_id
  1265. params = model_info.params.model_dump()
  1266. if params:
  1267. system = params.pop("system", None)
  1268. payload = apply_model_params_to_body_openai(params, payload)
  1269. payload = apply_system_prompt_to_body(system, payload, metadata, user)
  1270. # Check if user has access to the model
  1271. if user.role == "user":
  1272. if not (
  1273. user.id == model_info.user_id
  1274. or has_access(
  1275. user.id, type="read", access_control=model_info.access_control
  1276. )
  1277. ):
  1278. raise HTTPException(
  1279. status_code=403,
  1280. detail="Model not found",
  1281. )
  1282. else:
  1283. if user.role != "admin":
  1284. raise HTTPException(
  1285. status_code=403,
  1286. detail="Model not found",
  1287. )
  1288. if ":" not in payload["model"]:
  1289. payload["model"] = f"{payload['model']}:latest"
  1290. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  1291. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  1292. str(url_idx),
  1293. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  1294. )
  1295. prefix_id = api_config.get("prefix_id", None)
  1296. if prefix_id:
  1297. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  1298. return await send_post_request(
  1299. url=f"{url}/v1/chat/completions",
  1300. payload=json.dumps(payload),
  1301. stream=payload.get("stream", False),
  1302. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1303. user=user,
  1304. metadata=metadata,
  1305. )
  1306. @router.get("/v1/models")
  1307. @router.get("/v1/models/{url_idx}")
  1308. async def get_openai_models(
  1309. request: Request,
  1310. url_idx: Optional[int] = None,
  1311. user=Depends(get_verified_user),
  1312. ):
  1313. models = []
  1314. if url_idx is None:
  1315. model_list = await get_all_models(request, user=user)
  1316. models = [
  1317. {
  1318. "id": model["model"],
  1319. "object": "model",
  1320. "created": int(time.time()),
  1321. "owned_by": "openai",
  1322. }
  1323. for model in model_list["models"]
  1324. ]
  1325. else:
  1326. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1327. try:
  1328. r = requests.request(method="GET", url=f"{url}/api/tags")
  1329. r.raise_for_status()
  1330. model_list = r.json()
  1331. models = [
  1332. {
  1333. "id": model["model"],
  1334. "object": "model",
  1335. "created": int(time.time()),
  1336. "owned_by": "openai",
  1337. }
  1338. for model in models["models"]
  1339. ]
  1340. except Exception as e:
  1341. log.exception(e)
  1342. error_detail = "Open WebUI: Server Connection Error"
  1343. if r is not None:
  1344. try:
  1345. res = r.json()
  1346. if "error" in res:
  1347. error_detail = f"Ollama: {res['error']}"
  1348. except Exception:
  1349. error_detail = f"Ollama: {e}"
  1350. raise HTTPException(
  1351. status_code=r.status_code if r else 500,
  1352. detail=error_detail,
  1353. )
  1354. if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
  1355. # Filter models based on user access control
  1356. filtered_models = []
  1357. for model in models:
  1358. model_info = Models.get_model_by_id(model["id"])
  1359. if model_info:
  1360. if user.id == model_info.user_id or has_access(
  1361. user.id, type="read", access_control=model_info.access_control
  1362. ):
  1363. filtered_models.append(model)
  1364. models = filtered_models
  1365. return {
  1366. "data": models,
  1367. "object": "list",
  1368. }
class UrlForm(BaseModel):
    """Request body of ``download_model``: the URL of the file to fetch."""

    # Must start with one of the handler's allowed host prefixes.
    url: str
class UploadBlobForm(BaseModel):
    """Form carrying a single file name.

    NOTE(review): no handler in the visible part of this file references this
    class — confirm whether it is still used.
    """

    filename: str
  1373. def parse_huggingface_url(hf_url):
  1374. try:
  1375. # Parse the URL
  1376. parsed_url = urlparse(hf_url)
  1377. # Get the path and split it into components
  1378. path_components = parsed_url.path.split("/")
  1379. # Extract the desired output
  1380. model_file = path_components[-1]
  1381. return model_file
  1382. except ValueError:
  1383. return None
  1384. async def download_file_stream(
  1385. ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
  1386. ):
  1387. done = False
  1388. if os.path.exists(file_path):
  1389. current_size = os.path.getsize(file_path)
  1390. else:
  1391. current_size = 0
  1392. headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
  1393. timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
  1394. async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
  1395. async with session.get(
  1396. file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL
  1397. ) as response:
  1398. total_size = int(response.headers.get("content-length", 0)) + current_size
  1399. with open(file_path, "ab+") as file:
  1400. async for data in response.content.iter_chunked(chunk_size):
  1401. current_size += len(data)
  1402. file.write(data)
  1403. done = current_size == total_size
  1404. progress = round((current_size / total_size) * 100, 2)
  1405. yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
  1406. if done:
  1407. file.close()
  1408. with open(file_path, "rb") as file:
  1409. chunk_size = 1024 * 1024 * 2
  1410. hashed = calculate_sha256(file, chunk_size)
  1411. url = f"{ollama_url}/api/blobs/sha256:{hashed}"
  1412. with requests.Session() as session:
  1413. response = session.post(url, data=file, timeout=30)
  1414. if response.ok:
  1415. res = {
  1416. "done": done,
  1417. "blob": f"sha256:{hashed}",
  1418. "name": file_name,
  1419. }
  1420. os.remove(file_path)
  1421. yield f"data: {json.dumps(res)}\n\n"
  1422. else:
  1423. raise "Ollama: Could not create blob, Please try again."
  1424. # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
  1425. @router.post("/models/download")
  1426. @router.post("/models/download/{url_idx}")
  1427. async def download_model(
  1428. request: Request,
  1429. form_data: UrlForm,
  1430. url_idx: Optional[int] = None,
  1431. user=Depends(get_admin_user),
  1432. ):
  1433. allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
  1434. if not any(form_data.url.startswith(host) for host in allowed_hosts):
  1435. raise HTTPException(
  1436. status_code=400,
  1437. detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
  1438. )
  1439. if url_idx is None:
  1440. url_idx = 0
  1441. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1442. file_name = parse_huggingface_url(form_data.url)
  1443. if file_name:
  1444. file_path = f"{UPLOAD_DIR}/{file_name}"
  1445. return StreamingResponse(
  1446. download_file_stream(url, form_data.url, file_path, file_name),
  1447. )
  1448. else:
  1449. return None
  1450. # TODO: Progress bar does not reflect size & duration of upload.
  1451. @router.post("/models/upload")
  1452. @router.post("/models/upload/{url_idx}")
  1453. async def upload_model(
  1454. request: Request,
  1455. file: UploadFile = File(...),
  1456. url_idx: Optional[int] = None,
  1457. user=Depends(get_admin_user),
  1458. ):
  1459. if url_idx is None:
  1460. url_idx = 0
  1461. ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1462. filename = os.path.basename(file.filename)
  1463. file_path = os.path.join(UPLOAD_DIR, filename)
  1464. os.makedirs(UPLOAD_DIR, exist_ok=True)
  1465. # --- P1: save file locally ---
  1466. chunk_size = 1024 * 1024 * 2 # 2 MB chunks
  1467. with open(file_path, "wb") as out_f:
  1468. while True:
  1469. chunk = file.file.read(chunk_size)
  1470. # log.info(f"Chunk: {str(chunk)}") # DEBUG
  1471. if not chunk:
  1472. break
  1473. out_f.write(chunk)
  1474. async def file_process_stream():
  1475. nonlocal ollama_url
  1476. total_size = os.path.getsize(file_path)
  1477. log.info(f"Total Model Size: {str(total_size)}") # DEBUG
  1478. # --- P2: SSE progress + calculate sha256 hash ---
  1479. file_hash = calculate_sha256(file_path, chunk_size)
  1480. log.info(f"Model Hash: {str(file_hash)}") # DEBUG
  1481. try:
  1482. with open(file_path, "rb") as f:
  1483. bytes_read = 0
  1484. while chunk := f.read(chunk_size):
  1485. bytes_read += len(chunk)
  1486. progress = round(bytes_read / total_size * 100, 2)
  1487. data_msg = {
  1488. "progress": progress,
  1489. "total": total_size,
  1490. "completed": bytes_read,
  1491. }
  1492. yield f"data: {json.dumps(data_msg)}\n\n"
  1493. # --- P3: Upload to ollama /api/blobs ---
  1494. with open(file_path, "rb") as f:
  1495. url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
  1496. response = requests.post(url, data=f)
  1497. if response.ok:
  1498. log.info(f"Uploaded to /api/blobs") # DEBUG
  1499. # Remove local file
  1500. os.remove(file_path)
  1501. # Create model in ollama
  1502. model_name, ext = os.path.splitext(filename)
  1503. log.info(f"Created Model: {model_name}") # DEBUG
  1504. create_payload = {
  1505. "model": model_name,
  1506. # Reference the file by its original name => the uploaded blob's digest
  1507. "files": {filename: f"sha256:{file_hash}"},
  1508. }
  1509. log.info(f"Model Payload: {create_payload}") # DEBUG
  1510. # Call ollama /api/create
  1511. # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
  1512. create_resp = requests.post(
  1513. url=f"{ollama_url}/api/create",
  1514. headers={"Content-Type": "application/json"},
  1515. data=json.dumps(create_payload),
  1516. )
  1517. if create_resp.ok:
  1518. log.info(f"API SUCCESS!") # DEBUG
  1519. done_msg = {
  1520. "done": True,
  1521. "blob": f"sha256:{file_hash}",
  1522. "name": filename,
  1523. "model_created": model_name,
  1524. }
  1525. yield f"data: {json.dumps(done_msg)}\n\n"
  1526. else:
  1527. raise Exception(
  1528. f"Failed to create model in Ollama. {create_resp.text}"
  1529. )
  1530. else:
  1531. raise Exception("Ollama: Could not create blob, Please try again.")
  1532. except Exception as e:
  1533. res = {"error": str(e)}
  1534. yield f"data: {json.dumps(res)}\n\n"
  1535. return StreamingResponse(file_process_stream(), media_type="text/event-stream")