ollama.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800
# TODO: Implement a more intelligent load-balancing mechanism for distributing requests among multiple backend instances.
# The current implementation picks a backend uniformly at random (random.choice), which is not true round-robin. Consider algorithms such as
# weighted round-robin, least connections, or least response time for better resource utilization and performance.
  4. import asyncio
  5. import json
  6. import logging
  7. import os
  8. import random
  9. import re
  10. import time
  11. from datetime import datetime
  12. from typing import Optional, Union
  13. from urllib.parse import urlparse
  14. import aiohttp
  15. from aiocache import cached
  16. import requests
  17. from open_webui.models.users import UserModel
  18. from open_webui.env import (
  19. ENABLE_FORWARD_USER_INFO_HEADERS,
  20. )
  21. from fastapi import (
  22. Depends,
  23. FastAPI,
  24. File,
  25. HTTPException,
  26. Request,
  27. UploadFile,
  28. APIRouter,
  29. )
  30. from fastapi.middleware.cors import CORSMiddleware
  31. from fastapi.responses import StreamingResponse
  32. from pydantic import BaseModel, ConfigDict, validator
  33. from starlette.background import BackgroundTask
  34. from open_webui.models.models import Models
  35. from open_webui.utils.misc import (
  36. calculate_sha256,
  37. )
  38. from open_webui.utils.payload import (
  39. apply_model_params_to_body_ollama,
  40. apply_model_params_to_body_openai,
  41. apply_model_system_prompt_to_body,
  42. )
  43. from open_webui.utils.auth import get_admin_user, get_verified_user
  44. from open_webui.utils.access_control import has_access
  45. from open_webui.config import (
  46. UPLOAD_DIR,
  47. )
  48. from open_webui.env import (
  49. ENV,
  50. SRC_LOG_LEVELS,
  51. AIOHTTP_CLIENT_SESSION_SSL,
  52. AIOHTTP_CLIENT_TIMEOUT,
  53. AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
  54. BYPASS_MODEL_ACCESS_CONTROL,
  55. )
  56. from open_webui.constants import ERROR_MESSAGES
  57. log = logging.getLogger(__name__)
  58. log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
  59. ##########################################
  60. #
  61. # Utility functions
  62. #
  63. ##########################################
  64. async def send_get_request(url, key=None, user: UserModel = None):
  65. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
  66. try:
  67. async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
  68. async with session.get(
  69. url,
  70. headers={
  71. "Content-Type": "application/json",
  72. **({"Authorization": f"Bearer {key}"} if key else {}),
  73. **(
  74. {
  75. "X-OpenWebUI-User-Name": user.name,
  76. "X-OpenWebUI-User-Id": user.id,
  77. "X-OpenWebUI-User-Email": user.email,
  78. "X-OpenWebUI-User-Role": user.role,
  79. }
  80. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  81. else {}
  82. ),
  83. },
  84. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  85. ) as response:
  86. return await response.json()
  87. except Exception as e:
  88. # Handle connection error here
  89. log.error(f"Connection error: {e}")
  90. return None
  91. async def cleanup_response(
  92. response: Optional[aiohttp.ClientResponse],
  93. session: Optional[aiohttp.ClientSession],
  94. ):
  95. if response:
  96. response.close()
  97. if session:
  98. await session.close()
  99. async def send_post_request(
  100. url: str,
  101. payload: Union[str, bytes],
  102. stream: bool = True,
  103. key: Optional[str] = None,
  104. content_type: Optional[str] = None,
  105. user: UserModel = None,
  106. ):
  107. r = None
  108. try:
  109. session = aiohttp.ClientSession(
  110. trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  111. )
  112. r = await session.post(
  113. url,
  114. data=payload,
  115. headers={
  116. "Content-Type": "application/json",
  117. **({"Authorization": f"Bearer {key}"} if key else {}),
  118. **(
  119. {
  120. "X-OpenWebUI-User-Name": user.name,
  121. "X-OpenWebUI-User-Id": user.id,
  122. "X-OpenWebUI-User-Email": user.email,
  123. "X-OpenWebUI-User-Role": user.role,
  124. }
  125. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  126. else {}
  127. ),
  128. },
  129. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  130. )
  131. r.raise_for_status()
  132. if stream:
  133. response_headers = dict(r.headers)
  134. if content_type:
  135. response_headers["Content-Type"] = content_type
  136. return StreamingResponse(
  137. r.content,
  138. status_code=r.status,
  139. headers=response_headers,
  140. background=BackgroundTask(
  141. cleanup_response, response=r, session=session
  142. ),
  143. )
  144. else:
  145. res = await r.json()
  146. await cleanup_response(r, session)
  147. return res
  148. except Exception as e:
  149. detail = None
  150. if r is not None:
  151. try:
  152. res = await r.json()
  153. if "error" in res:
  154. detail = f"Ollama: {res.get('error', 'Unknown error')}"
  155. except Exception:
  156. detail = f"Ollama: {e}"
  157. raise HTTPException(
  158. status_code=r.status if r else 500,
  159. detail=detail if detail else "Open WebUI: Server Connection Error",
  160. )
  161. def get_api_key(idx, url, configs):
  162. parsed_url = urlparse(url)
  163. base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
  164. return configs.get(str(idx), configs.get(base_url, {})).get(
  165. "key", None
  166. ) # Legacy support
  167. ##########################################
  168. #
  169. # API routes
  170. #
  171. ##########################################
  172. router = APIRouter()
  173. @router.head("/")
  174. @router.get("/")
  175. async def get_status():
  176. return {"status": True}
  177. class ConnectionVerificationForm(BaseModel):
  178. url: str
  179. key: Optional[str] = None
  180. @router.post("/verify")
  181. async def verify_connection(
  182. form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
  183. ):
  184. url = form_data.url
  185. key = form_data.key
  186. async with aiohttp.ClientSession(
  187. trust_env=True,
  188. timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
  189. ) as session:
  190. try:
  191. async with session.get(
  192. f"{url}/api/version",
  193. headers={
  194. **({"Authorization": f"Bearer {key}"} if key else {}),
  195. **(
  196. {
  197. "X-OpenWebUI-User-Name": user.name,
  198. "X-OpenWebUI-User-Id": user.id,
  199. "X-OpenWebUI-User-Email": user.email,
  200. "X-OpenWebUI-User-Role": user.role,
  201. }
  202. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  203. else {}
  204. ),
  205. },
  206. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  207. ) as r:
  208. if r.status != 200:
  209. detail = f"HTTP Error: {r.status}"
  210. res = await r.json()
  211. if "error" in res:
  212. detail = f"External Error: {res['error']}"
  213. raise Exception(detail)
  214. data = await r.json()
  215. return data
  216. except aiohttp.ClientError as e:
  217. log.exception(f"Client error: {str(e)}")
  218. raise HTTPException(
  219. status_code=500, detail="Open WebUI: Server Connection Error"
  220. )
  221. except Exception as e:
  222. log.exception(f"Unexpected error: {e}")
  223. error_detail = f"Unexpected error: {str(e)}"
  224. raise HTTPException(status_code=500, detail=error_detail)
  225. @router.get("/config")
  226. async def get_config(request: Request, user=Depends(get_admin_user)):
  227. return {
  228. "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
  229. "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
  230. "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
  231. }
  232. class OllamaConfigForm(BaseModel):
  233. ENABLE_OLLAMA_API: Optional[bool] = None
  234. OLLAMA_BASE_URLS: list[str]
  235. OLLAMA_API_CONFIGS: dict
  236. @router.post("/config/update")
  237. async def update_config(
  238. request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
  239. ):
  240. request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
  241. request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
  242. request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
  243. # Remove the API configs that are not in the API URLS
  244. keys = list(map(str, range(len(request.app.state.config.OLLAMA_BASE_URLS))))
  245. request.app.state.config.OLLAMA_API_CONFIGS = {
  246. key: value
  247. for key, value in request.app.state.config.OLLAMA_API_CONFIGS.items()
  248. if key in keys
  249. }
  250. return {
  251. "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
  252. "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
  253. "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
  254. }
  255. def merge_ollama_models_lists(model_lists):
  256. merged_models = {}
  257. for idx, model_list in enumerate(model_lists):
  258. if model_list is not None:
  259. for model in model_list:
  260. id = model["model"]
  261. if id not in merged_models:
  262. model["urls"] = [idx]
  263. merged_models[id] = model
  264. else:
  265. merged_models[id]["urls"].append(idx)
  266. return list(merged_models.values())
  267. @cached(ttl=1)
  268. async def get_all_models(request: Request, user: UserModel = None):
  269. log.info("get_all_models()")
  270. if request.app.state.config.ENABLE_OLLAMA_API:
  271. request_tasks = []
  272. for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
  273. if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
  274. url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
  275. ):
  276. request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
  277. else:
  278. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  279. str(idx),
  280. request.app.state.config.OLLAMA_API_CONFIGS.get(
  281. url, {}
  282. ), # Legacy support
  283. )
  284. enable = api_config.get("enable", True)
  285. key = api_config.get("key", None)
  286. if enable:
  287. request_tasks.append(
  288. send_get_request(f"{url}/api/tags", key, user=user)
  289. )
  290. else:
  291. request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
  292. responses = await asyncio.gather(*request_tasks)
  293. for idx, response in enumerate(responses):
  294. if response:
  295. url = request.app.state.config.OLLAMA_BASE_URLS[idx]
  296. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  297. str(idx),
  298. request.app.state.config.OLLAMA_API_CONFIGS.get(
  299. url, {}
  300. ), # Legacy support
  301. )
  302. connection_type = api_config.get("connection_type", "local")
  303. prefix_id = api_config.get("prefix_id", None)
  304. tags = api_config.get("tags", [])
  305. model_ids = api_config.get("model_ids", [])
  306. if len(model_ids) != 0 and "models" in response:
  307. response["models"] = list(
  308. filter(
  309. lambda model: model["model"] in model_ids,
  310. response["models"],
  311. )
  312. )
  313. for model in response.get("models", []):
  314. if prefix_id:
  315. model["model"] = f"{prefix_id}.{model['model']}"
  316. if tags:
  317. model["tags"] = tags
  318. if connection_type:
  319. model["connection_type"] = connection_type
  320. models = {
  321. "models": merge_ollama_models_lists(
  322. map(
  323. lambda response: response.get("models", []) if response else None,
  324. responses,
  325. )
  326. )
  327. }
  328. try:
  329. loaded_models = await get_ollama_loaded_models(request, user=user)
  330. expires_map = {
  331. m["name"]: m["expires_at"]
  332. for m in loaded_models["models"]
  333. if "expires_at" in m
  334. }
  335. for m in models["models"]:
  336. if m["name"] in expires_map:
  337. # Parse ISO8601 datetime with offset, get unix timestamp as int
  338. dt = datetime.fromisoformat(expires_map[m["name"]])
  339. m["expires_at"] = int(dt.timestamp())
  340. except Exception as e:
  341. log.debug(f"Failed to get loaded models: {e}")
  342. else:
  343. models = {"models": []}
  344. request.app.state.OLLAMA_MODELS = {
  345. model["model"]: model for model in models["models"]
  346. }
  347. return models
  348. async def get_filtered_models(models, user):
  349. # Filter models based on user access control
  350. filtered_models = []
  351. for model in models.get("models", []):
  352. model_info = Models.get_model_by_id(model["model"])
  353. if model_info:
  354. if user.id == model_info.user_id or has_access(
  355. user.id, type="read", access_control=model_info.access_control
  356. ):
  357. filtered_models.append(model)
  358. return filtered_models
  359. @router.get("/api/tags")
  360. @router.get("/api/tags/{url_idx}")
  361. async def get_ollama_tags(
  362. request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
  363. ):
  364. models = []
  365. if url_idx is None:
  366. models = await get_all_models(request, user=user)
  367. else:
  368. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  369. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  370. r = None
  371. try:
  372. r = requests.request(
  373. method="GET",
  374. url=f"{url}/api/tags",
  375. headers={
  376. **({"Authorization": f"Bearer {key}"} if key else {}),
  377. **(
  378. {
  379. "X-OpenWebUI-User-Name": user.name,
  380. "X-OpenWebUI-User-Id": user.id,
  381. "X-OpenWebUI-User-Email": user.email,
  382. "X-OpenWebUI-User-Role": user.role,
  383. }
  384. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  385. else {}
  386. ),
  387. },
  388. )
  389. r.raise_for_status()
  390. models = r.json()
  391. except Exception as e:
  392. log.exception(e)
  393. detail = None
  394. if r is not None:
  395. try:
  396. res = r.json()
  397. if "error" in res:
  398. detail = f"Ollama: {res['error']}"
  399. except Exception:
  400. detail = f"Ollama: {e}"
  401. raise HTTPException(
  402. status_code=r.status_code if r else 500,
  403. detail=detail if detail else "Open WebUI: Server Connection Error",
  404. )
  405. if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
  406. models["models"] = await get_filtered_models(models, user)
  407. return models
  408. @router.get("/api/ps")
  409. async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)):
  410. """
  411. List models that are currently loaded into Ollama memory, and which node they are loaded on.
  412. """
  413. if request.app.state.config.ENABLE_OLLAMA_API:
  414. request_tasks = []
  415. for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
  416. if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
  417. url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
  418. ):
  419. request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
  420. else:
  421. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  422. str(idx),
  423. request.app.state.config.OLLAMA_API_CONFIGS.get(
  424. url, {}
  425. ), # Legacy support
  426. )
  427. enable = api_config.get("enable", True)
  428. key = api_config.get("key", None)
  429. if enable:
  430. request_tasks.append(
  431. send_get_request(f"{url}/api/ps", key, user=user)
  432. )
  433. else:
  434. request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
  435. responses = await asyncio.gather(*request_tasks)
  436. for idx, response in enumerate(responses):
  437. if response:
  438. url = request.app.state.config.OLLAMA_BASE_URLS[idx]
  439. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  440. str(idx),
  441. request.app.state.config.OLLAMA_API_CONFIGS.get(
  442. url, {}
  443. ), # Legacy support
  444. )
  445. prefix_id = api_config.get("prefix_id", None)
  446. for model in response.get("models", []):
  447. if prefix_id:
  448. model["model"] = f"{prefix_id}.{model['model']}"
  449. models = {
  450. "models": merge_ollama_models_lists(
  451. map(
  452. lambda response: response.get("models", []) if response else None,
  453. responses,
  454. )
  455. )
  456. }
  457. else:
  458. models = {"models": []}
  459. return models
  460. @router.get("/api/version")
  461. @router.get("/api/version/{url_idx}")
  462. async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
  463. if request.app.state.config.ENABLE_OLLAMA_API:
  464. if url_idx is None:
  465. # returns lowest version
  466. request_tasks = []
  467. for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
  468. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  469. str(idx),
  470. request.app.state.config.OLLAMA_API_CONFIGS.get(
  471. url, {}
  472. ), # Legacy support
  473. )
  474. enable = api_config.get("enable", True)
  475. key = api_config.get("key", None)
  476. if enable:
  477. request_tasks.append(
  478. send_get_request(
  479. f"{url}/api/version",
  480. key,
  481. )
  482. )
  483. responses = await asyncio.gather(*request_tasks)
  484. responses = list(filter(lambda x: x is not None, responses))
  485. if len(responses) > 0:
  486. lowest_version = min(
  487. responses,
  488. key=lambda x: tuple(
  489. map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
  490. ),
  491. )
  492. return {"version": lowest_version["version"]}
  493. else:
  494. raise HTTPException(
  495. status_code=500,
  496. detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
  497. )
  498. else:
  499. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  500. r = None
  501. try:
  502. r = requests.request(method="GET", url=f"{url}/api/version")
  503. r.raise_for_status()
  504. return r.json()
  505. except Exception as e:
  506. log.exception(e)
  507. detail = None
  508. if r is not None:
  509. try:
  510. res = r.json()
  511. if "error" in res:
  512. detail = f"Ollama: {res['error']}"
  513. except Exception:
  514. detail = f"Ollama: {e}"
  515. raise HTTPException(
  516. status_code=r.status_code if r else 500,
  517. detail=detail if detail else "Open WebUI: Server Connection Error",
  518. )
  519. else:
  520. return {"version": False}
  521. class ModelNameForm(BaseModel):
  522. name: str
  523. @router.post("/api/unload")
  524. async def unload_model(
  525. request: Request,
  526. form_data: ModelNameForm,
  527. user=Depends(get_admin_user),
  528. ):
  529. model_name = form_data.name
  530. if not model_name:
  531. raise HTTPException(
  532. status_code=400, detail="Missing 'name' of model to unload."
  533. )
  534. # Refresh/load models if needed, get mapping from name to URLs
  535. await get_all_models(request, user=user)
  536. models = request.app.state.OLLAMA_MODELS
  537. # Canonicalize model name (if not supplied with version)
  538. if ":" not in model_name:
  539. model_name = f"{model_name}:latest"
  540. if model_name not in models:
  541. raise HTTPException(
  542. status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name)
  543. )
  544. url_indices = models[model_name]["urls"]
  545. # Send unload to ALL url_indices
  546. results = []
  547. errors = []
  548. for idx in url_indices:
  549. url = request.app.state.config.OLLAMA_BASE_URLS[idx]
  550. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  551. str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
  552. )
  553. key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  554. prefix_id = api_config.get("prefix_id", None)
  555. if prefix_id and model_name.startswith(f"{prefix_id}."):
  556. model_name = model_name[len(f"{prefix_id}.") :]
  557. payload = {"model": model_name, "keep_alive": 0, "prompt": ""}
  558. try:
  559. res = await send_post_request(
  560. url=f"{url}/api/generate",
  561. payload=json.dumps(payload),
  562. stream=False,
  563. key=key,
  564. user=user,
  565. )
  566. results.append({"url_idx": idx, "success": True, "response": res})
  567. except Exception as e:
  568. log.exception(f"Failed to unload model on node {idx}: {e}")
  569. errors.append({"url_idx": idx, "success": False, "error": str(e)})
  570. if len(errors) > 0:
  571. raise HTTPException(
  572. status_code=500,
  573. detail=f"Failed to unload model on {len(errors)} nodes: {errors}",
  574. )
  575. return {"status": True}
  576. @router.post("/api/pull")
  577. @router.post("/api/pull/{url_idx}")
  578. async def pull_model(
  579. request: Request,
  580. form_data: ModelNameForm,
  581. url_idx: int = 0,
  582. user=Depends(get_admin_user),
  583. ):
  584. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  585. log.info(f"url: {url}")
  586. # Admin should be able to pull models from any source
  587. payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
  588. return await send_post_request(
  589. url=f"{url}/api/pull",
  590. payload=json.dumps(payload),
  591. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  592. user=user,
  593. )
  594. class PushModelForm(BaseModel):
  595. name: str
  596. insecure: Optional[bool] = None
  597. stream: Optional[bool] = None
  598. @router.delete("/api/push")
  599. @router.delete("/api/push/{url_idx}")
  600. async def push_model(
  601. request: Request,
  602. form_data: PushModelForm,
  603. url_idx: Optional[int] = None,
  604. user=Depends(get_admin_user),
  605. ):
  606. if url_idx is None:
  607. await get_all_models(request, user=user)
  608. models = request.app.state.OLLAMA_MODELS
  609. if form_data.name in models:
  610. url_idx = models[form_data.name]["urls"][0]
  611. else:
  612. raise HTTPException(
  613. status_code=400,
  614. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
  615. )
  616. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  617. log.debug(f"url: {url}")
  618. return await send_post_request(
  619. url=f"{url}/api/push",
  620. payload=form_data.model_dump_json(exclude_none=True).encode(),
  621. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  622. user=user,
  623. )
  624. class CreateModelForm(BaseModel):
  625. model: Optional[str] = None
  626. stream: Optional[bool] = None
  627. path: Optional[str] = None
  628. model_config = ConfigDict(extra="allow")
  629. @router.post("/api/create")
  630. @router.post("/api/create/{url_idx}")
  631. async def create_model(
  632. request: Request,
  633. form_data: CreateModelForm,
  634. url_idx: int = 0,
  635. user=Depends(get_admin_user),
  636. ):
  637. log.debug(f"form_data: {form_data}")
  638. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  639. return await send_post_request(
  640. url=f"{url}/api/create",
  641. payload=form_data.model_dump_json(exclude_none=True).encode(),
  642. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  643. user=user,
  644. )
  645. class CopyModelForm(BaseModel):
  646. source: str
  647. destination: str
  648. @router.post("/api/copy")
  649. @router.post("/api/copy/{url_idx}")
  650. async def copy_model(
  651. request: Request,
  652. form_data: CopyModelForm,
  653. url_idx: Optional[int] = None,
  654. user=Depends(get_admin_user),
  655. ):
  656. if url_idx is None:
  657. await get_all_models(request, user=user)
  658. models = request.app.state.OLLAMA_MODELS
  659. if form_data.source in models:
  660. url_idx = models[form_data.source]["urls"][0]
  661. else:
  662. raise HTTPException(
  663. status_code=400,
  664. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
  665. )
  666. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  667. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  668. try:
  669. r = requests.request(
  670. method="POST",
  671. url=f"{url}/api/copy",
  672. headers={
  673. "Content-Type": "application/json",
  674. **({"Authorization": f"Bearer {key}"} if key else {}),
  675. **(
  676. {
  677. "X-OpenWebUI-User-Name": user.name,
  678. "X-OpenWebUI-User-Id": user.id,
  679. "X-OpenWebUI-User-Email": user.email,
  680. "X-OpenWebUI-User-Role": user.role,
  681. }
  682. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  683. else {}
  684. ),
  685. },
  686. data=form_data.model_dump_json(exclude_none=True).encode(),
  687. )
  688. r.raise_for_status()
  689. log.debug(f"r.text: {r.text}")
  690. return True
  691. except Exception as e:
  692. log.exception(e)
  693. detail = None
  694. if r is not None:
  695. try:
  696. res = r.json()
  697. if "error" in res:
  698. detail = f"Ollama: {res['error']}"
  699. except Exception:
  700. detail = f"Ollama: {e}"
  701. raise HTTPException(
  702. status_code=r.status_code if r else 500,
  703. detail=detail if detail else "Open WebUI: Server Connection Error",
  704. )
  705. @router.delete("/api/delete")
  706. @router.delete("/api/delete/{url_idx}")
  707. async def delete_model(
  708. request: Request,
  709. form_data: ModelNameForm,
  710. url_idx: Optional[int] = None,
  711. user=Depends(get_admin_user),
  712. ):
  713. if url_idx is None:
  714. await get_all_models(request, user=user)
  715. models = request.app.state.OLLAMA_MODELS
  716. if form_data.name in models:
  717. url_idx = models[form_data.name]["urls"][0]
  718. else:
  719. raise HTTPException(
  720. status_code=400,
  721. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
  722. )
  723. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  724. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  725. try:
  726. r = requests.request(
  727. method="DELETE",
  728. url=f"{url}/api/delete",
  729. data=form_data.model_dump_json(exclude_none=True).encode(),
  730. headers={
  731. "Content-Type": "application/json",
  732. **({"Authorization": f"Bearer {key}"} if key else {}),
  733. **(
  734. {
  735. "X-OpenWebUI-User-Name": user.name,
  736. "X-OpenWebUI-User-Id": user.id,
  737. "X-OpenWebUI-User-Email": user.email,
  738. "X-OpenWebUI-User-Role": user.role,
  739. }
  740. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  741. else {}
  742. ),
  743. },
  744. )
  745. r.raise_for_status()
  746. log.debug(f"r.text: {r.text}")
  747. return True
  748. except Exception as e:
  749. log.exception(e)
  750. detail = None
  751. if r is not None:
  752. try:
  753. res = r.json()
  754. if "error" in res:
  755. detail = f"Ollama: {res['error']}"
  756. except Exception:
  757. detail = f"Ollama: {e}"
  758. raise HTTPException(
  759. status_code=r.status_code if r else 500,
  760. detail=detail if detail else "Open WebUI: Server Connection Error",
  761. )
  762. @router.post("/api/show")
  763. async def show_model_info(
  764. request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
  765. ):
  766. await get_all_models(request, user=user)
  767. models = request.app.state.OLLAMA_MODELS
  768. if form_data.name not in models:
  769. raise HTTPException(
  770. status_code=400,
  771. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
  772. )
  773. url_idx = random.choice(models[form_data.name]["urls"])
  774. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  775. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  776. try:
  777. r = requests.request(
  778. method="POST",
  779. url=f"{url}/api/show",
  780. headers={
  781. "Content-Type": "application/json",
  782. **({"Authorization": f"Bearer {key}"} if key else {}),
  783. **(
  784. {
  785. "X-OpenWebUI-User-Name": user.name,
  786. "X-OpenWebUI-User-Id": user.id,
  787. "X-OpenWebUI-User-Email": user.email,
  788. "X-OpenWebUI-User-Role": user.role,
  789. }
  790. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  791. else {}
  792. ),
  793. },
  794. data=form_data.model_dump_json(exclude_none=True).encode(),
  795. )
  796. r.raise_for_status()
  797. return r.json()
  798. except Exception as e:
  799. log.exception(e)
  800. detail = None
  801. if r is not None:
  802. try:
  803. res = r.json()
  804. if "error" in res:
  805. detail = f"Ollama: {res['error']}"
  806. except Exception:
  807. detail = f"Ollama: {e}"
  808. raise HTTPException(
  809. status_code=r.status_code if r else 500,
  810. detail=detail if detail else "Open WebUI: Server Connection Error",
  811. )
  812. class GenerateEmbedForm(BaseModel):
  813. model: str
  814. input: list[str] | str
  815. truncate: Optional[bool] = None
  816. options: Optional[dict] = None
  817. keep_alive: Optional[Union[int, str]] = None
  818. @router.post("/api/embed")
  819. @router.post("/api/embed/{url_idx}")
  820. async def embed(
  821. request: Request,
  822. form_data: GenerateEmbedForm,
  823. url_idx: Optional[int] = None,
  824. user=Depends(get_verified_user),
  825. ):
  826. log.info(f"generate_ollama_batch_embeddings {form_data}")
  827. if url_idx is None:
  828. await get_all_models(request, user=user)
  829. models = request.app.state.OLLAMA_MODELS
  830. model = form_data.model
  831. if ":" not in model:
  832. model = f"{model}:latest"
  833. if model in models:
  834. url_idx = random.choice(models[model]["urls"])
  835. else:
  836. raise HTTPException(
  837. status_code=400,
  838. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  839. )
  840. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  841. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  842. str(url_idx),
  843. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  844. )
  845. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  846. prefix_id = api_config.get("prefix_id", None)
  847. if prefix_id:
  848. form_data.model = form_data.model.replace(f"{prefix_id}.", "")
  849. try:
  850. r = requests.request(
  851. method="POST",
  852. url=f"{url}/api/embed",
  853. headers={
  854. "Content-Type": "application/json",
  855. **({"Authorization": f"Bearer {key}"} if key else {}),
  856. **(
  857. {
  858. "X-OpenWebUI-User-Name": user.name,
  859. "X-OpenWebUI-User-Id": user.id,
  860. "X-OpenWebUI-User-Email": user.email,
  861. "X-OpenWebUI-User-Role": user.role,
  862. }
  863. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  864. else {}
  865. ),
  866. },
  867. data=form_data.model_dump_json(exclude_none=True).encode(),
  868. )
  869. r.raise_for_status()
  870. data = r.json()
  871. return data
  872. except Exception as e:
  873. log.exception(e)
  874. detail = None
  875. if r is not None:
  876. try:
  877. res = r.json()
  878. if "error" in res:
  879. detail = f"Ollama: {res['error']}"
  880. except Exception:
  881. detail = f"Ollama: {e}"
  882. raise HTTPException(
  883. status_code=r.status_code if r else 500,
  884. detail=detail if detail else "Open WebUI: Server Connection Error",
  885. )
  886. class GenerateEmbeddingsForm(BaseModel):
  887. model: str
  888. prompt: str
  889. options: Optional[dict] = None
  890. keep_alive: Optional[Union[int, str]] = None
  891. @router.post("/api/embeddings")
  892. @router.post("/api/embeddings/{url_idx}")
  893. async def embeddings(
  894. request: Request,
  895. form_data: GenerateEmbeddingsForm,
  896. url_idx: Optional[int] = None,
  897. user=Depends(get_verified_user),
  898. ):
  899. log.info(f"generate_ollama_embeddings {form_data}")
  900. if url_idx is None:
  901. await get_all_models(request, user=user)
  902. models = request.app.state.OLLAMA_MODELS
  903. model = form_data.model
  904. if ":" not in model:
  905. model = f"{model}:latest"
  906. if model in models:
  907. url_idx = random.choice(models[model]["urls"])
  908. else:
  909. raise HTTPException(
  910. status_code=400,
  911. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  912. )
  913. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  914. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  915. str(url_idx),
  916. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  917. )
  918. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  919. prefix_id = api_config.get("prefix_id", None)
  920. if prefix_id:
  921. form_data.model = form_data.model.replace(f"{prefix_id}.", "")
  922. try:
  923. r = requests.request(
  924. method="POST",
  925. url=f"{url}/api/embeddings",
  926. headers={
  927. "Content-Type": "application/json",
  928. **({"Authorization": f"Bearer {key}"} if key else {}),
  929. **(
  930. {
  931. "X-OpenWebUI-User-Name": user.name,
  932. "X-OpenWebUI-User-Id": user.id,
  933. "X-OpenWebUI-User-Email": user.email,
  934. "X-OpenWebUI-User-Role": user.role,
  935. }
  936. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  937. else {}
  938. ),
  939. },
  940. data=form_data.model_dump_json(exclude_none=True).encode(),
  941. )
  942. r.raise_for_status()
  943. data = r.json()
  944. return data
  945. except Exception as e:
  946. log.exception(e)
  947. detail = None
  948. if r is not None:
  949. try:
  950. res = r.json()
  951. if "error" in res:
  952. detail = f"Ollama: {res['error']}"
  953. except Exception:
  954. detail = f"Ollama: {e}"
  955. raise HTTPException(
  956. status_code=r.status_code if r else 500,
  957. detail=detail if detail else "Open WebUI: Server Connection Error",
  958. )
  959. class GenerateCompletionForm(BaseModel):
  960. model: str
  961. prompt: str
  962. suffix: Optional[str] = None
  963. images: Optional[list[str]] = None
  964. format: Optional[Union[dict, str]] = None
  965. options: Optional[dict] = None
  966. system: Optional[str] = None
  967. template: Optional[str] = None
  968. context: Optional[list[int]] = None
  969. stream: Optional[bool] = True
  970. raw: Optional[bool] = None
  971. keep_alive: Optional[Union[int, str]] = None
  972. @router.post("/api/generate")
  973. @router.post("/api/generate/{url_idx}")
  974. async def generate_completion(
  975. request: Request,
  976. form_data: GenerateCompletionForm,
  977. url_idx: Optional[int] = None,
  978. user=Depends(get_verified_user),
  979. ):
  980. if url_idx is None:
  981. await get_all_models(request, user=user)
  982. models = request.app.state.OLLAMA_MODELS
  983. model = form_data.model
  984. if ":" not in model:
  985. model = f"{model}:latest"
  986. if model in models:
  987. url_idx = random.choice(models[model]["urls"])
  988. else:
  989. raise HTTPException(
  990. status_code=400,
  991. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  992. )
  993. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  994. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  995. str(url_idx),
  996. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  997. )
  998. prefix_id = api_config.get("prefix_id", None)
  999. if prefix_id:
  1000. form_data.model = form_data.model.replace(f"{prefix_id}.", "")
  1001. return await send_post_request(
  1002. url=f"{url}/api/generate",
  1003. payload=form_data.model_dump_json(exclude_none=True).encode(),
  1004. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1005. user=user,
  1006. )
  1007. class ChatMessage(BaseModel):
  1008. role: str
  1009. content: Optional[str] = None
  1010. tool_calls: Optional[list[dict]] = None
  1011. images: Optional[list[str]] = None
  1012. @validator("content", pre=True)
  1013. @classmethod
  1014. def check_at_least_one_field(cls, field_value, values, **kwargs):
  1015. # Raise an error if both 'content' and 'tool_calls' are None
  1016. if field_value is None and (
  1017. "tool_calls" not in values or values["tool_calls"] is None
  1018. ):
  1019. raise ValueError(
  1020. "At least one of 'content' or 'tool_calls' must be provided"
  1021. )
  1022. return field_value
  1023. class GenerateChatCompletionForm(BaseModel):
  1024. model: str
  1025. messages: list[ChatMessage]
  1026. format: Optional[Union[dict, str]] = None
  1027. options: Optional[dict] = None
  1028. template: Optional[str] = None
  1029. stream: Optional[bool] = True
  1030. keep_alive: Optional[Union[int, str]] = None
  1031. tools: Optional[list[dict]] = None
  1032. async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
  1033. if url_idx is None:
  1034. models = request.app.state.OLLAMA_MODELS
  1035. if model not in models:
  1036. raise HTTPException(
  1037. status_code=400,
  1038. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
  1039. )
  1040. url_idx = random.choice(models[model].get("urls", []))
  1041. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1042. return url, url_idx
  1043. @router.post("/api/chat")
  1044. @router.post("/api/chat/{url_idx}")
  1045. async def generate_chat_completion(
  1046. request: Request,
  1047. form_data: dict,
  1048. url_idx: Optional[int] = None,
  1049. user=Depends(get_verified_user),
  1050. bypass_filter: Optional[bool] = False,
  1051. ):
  1052. if BYPASS_MODEL_ACCESS_CONTROL:
  1053. bypass_filter = True
  1054. metadata = form_data.pop("metadata", None)
  1055. try:
  1056. form_data = GenerateChatCompletionForm(**form_data)
  1057. except Exception as e:
  1058. log.exception(e)
  1059. raise HTTPException(
  1060. status_code=400,
  1061. detail=str(e),
  1062. )
  1063. payload = {**form_data.model_dump(exclude_none=True)}
  1064. if "metadata" in payload:
  1065. del payload["metadata"]
  1066. model_id = payload["model"]
  1067. model_info = Models.get_model_by_id(model_id)
  1068. if model_info:
  1069. if model_info.base_model_id:
  1070. payload["model"] = model_info.base_model_id
  1071. params = model_info.params.model_dump()
  1072. if params:
  1073. system = params.pop("system", None)
  1074. # Unlike OpenAI, Ollama does not support params directly in the body
  1075. payload["options"] = apply_model_params_to_body_ollama(
  1076. params, (payload.get("options", {}) or {})
  1077. )
  1078. payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
  1079. # Check if user has access to the model
  1080. if not bypass_filter and user.role == "user":
  1081. if not (
  1082. user.id == model_info.user_id
  1083. or has_access(
  1084. user.id, type="read", access_control=model_info.access_control
  1085. )
  1086. ):
  1087. raise HTTPException(
  1088. status_code=403,
  1089. detail="Model not found",
  1090. )
  1091. elif not bypass_filter:
  1092. if user.role != "admin":
  1093. raise HTTPException(
  1094. status_code=403,
  1095. detail="Model not found",
  1096. )
  1097. if ":" not in payload["model"]:
  1098. payload["model"] = f"{payload['model']}:latest"
  1099. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  1100. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  1101. str(url_idx),
  1102. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  1103. )
  1104. prefix_id = api_config.get("prefix_id", None)
  1105. if prefix_id:
  1106. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  1107. # payload["keep_alive"] = -1 # keep alive forever
  1108. return await send_post_request(
  1109. url=f"{url}/api/chat",
  1110. payload=json.dumps(payload),
  1111. stream=form_data.stream,
  1112. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1113. content_type="application/x-ndjson",
  1114. user=user,
  1115. )
  1116. # TODO: we should update this part once Ollama supports other types
class OpenAIChatMessageContent(BaseModel):
    # One element of a list-valued `OpenAIChatMessage.content`. Only `type`
    # is declared; `extra="allow"` lets undeclared keys through.
    type: str
    model_config = ConfigDict(extra="allow")
  1120. class OpenAIChatMessage(BaseModel):
  1121. role: str
  1122. content: Union[Optional[str], list[OpenAIChatMessageContent]]
  1123. model_config = ConfigDict(extra="allow")
class OpenAIChatCompletionForm(BaseModel):
    # Body validated by `generate_openai_chat_completion`. Only `model` and
    # `messages` are declared; `extra="allow"` lets undeclared keys through.
    model: str
    messages: list[OpenAIChatMessage]
    model_config = ConfigDict(extra="allow")
class OpenAICompletionForm(BaseModel):
    # Body validated by `generate_openai_completion`. Only `model` and
    # `prompt` are declared; `extra="allow"` lets undeclared keys through.
    model: str
    prompt: str
    model_config = ConfigDict(extra="allow")
  1132. @router.post("/v1/completions")
  1133. @router.post("/v1/completions/{url_idx}")
  1134. async def generate_openai_completion(
  1135. request: Request,
  1136. form_data: dict,
  1137. url_idx: Optional[int] = None,
  1138. user=Depends(get_verified_user),
  1139. ):
  1140. try:
  1141. form_data = OpenAICompletionForm(**form_data)
  1142. except Exception as e:
  1143. log.exception(e)
  1144. raise HTTPException(
  1145. status_code=400,
  1146. detail=str(e),
  1147. )
  1148. payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
  1149. if "metadata" in payload:
  1150. del payload["metadata"]
  1151. model_id = form_data.model
  1152. if ":" not in model_id:
  1153. model_id = f"{model_id}:latest"
  1154. model_info = Models.get_model_by_id(model_id)
  1155. if model_info:
  1156. if model_info.base_model_id:
  1157. payload["model"] = model_info.base_model_id
  1158. params = model_info.params.model_dump()
  1159. if params:
  1160. payload = apply_model_params_to_body_openai(params, payload)
  1161. # Check if user has access to the model
  1162. if user.role == "user":
  1163. if not (
  1164. user.id == model_info.user_id
  1165. or has_access(
  1166. user.id, type="read", access_control=model_info.access_control
  1167. )
  1168. ):
  1169. raise HTTPException(
  1170. status_code=403,
  1171. detail="Model not found",
  1172. )
  1173. else:
  1174. if user.role != "admin":
  1175. raise HTTPException(
  1176. status_code=403,
  1177. detail="Model not found",
  1178. )
  1179. if ":" not in payload["model"]:
  1180. payload["model"] = f"{payload['model']}:latest"
  1181. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  1182. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  1183. str(url_idx),
  1184. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  1185. )
  1186. prefix_id = api_config.get("prefix_id", None)
  1187. if prefix_id:
  1188. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  1189. return await send_post_request(
  1190. url=f"{url}/v1/completions",
  1191. payload=json.dumps(payload),
  1192. stream=payload.get("stream", False),
  1193. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1194. user=user,
  1195. )
  1196. @router.post("/v1/chat/completions")
  1197. @router.post("/v1/chat/completions/{url_idx}")
  1198. async def generate_openai_chat_completion(
  1199. request: Request,
  1200. form_data: dict,
  1201. url_idx: Optional[int] = None,
  1202. user=Depends(get_verified_user),
  1203. ):
  1204. metadata = form_data.pop("metadata", None)
  1205. try:
  1206. completion_form = OpenAIChatCompletionForm(**form_data)
  1207. except Exception as e:
  1208. log.exception(e)
  1209. raise HTTPException(
  1210. status_code=400,
  1211. detail=str(e),
  1212. )
  1213. payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
  1214. if "metadata" in payload:
  1215. del payload["metadata"]
  1216. model_id = completion_form.model
  1217. if ":" not in model_id:
  1218. model_id = f"{model_id}:latest"
  1219. model_info = Models.get_model_by_id(model_id)
  1220. if model_info:
  1221. if model_info.base_model_id:
  1222. payload["model"] = model_info.base_model_id
  1223. params = model_info.params.model_dump()
  1224. if params:
  1225. system = params.pop("system", None)
  1226. payload = apply_model_params_to_body_openai(params, payload)
  1227. payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
  1228. # Check if user has access to the model
  1229. if user.role == "user":
  1230. if not (
  1231. user.id == model_info.user_id
  1232. or has_access(
  1233. user.id, type="read", access_control=model_info.access_control
  1234. )
  1235. ):
  1236. raise HTTPException(
  1237. status_code=403,
  1238. detail="Model not found",
  1239. )
  1240. else:
  1241. if user.role != "admin":
  1242. raise HTTPException(
  1243. status_code=403,
  1244. detail="Model not found",
  1245. )
  1246. if ":" not in payload["model"]:
  1247. payload["model"] = f"{payload['model']}:latest"
  1248. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  1249. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  1250. str(url_idx),
  1251. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  1252. )
  1253. prefix_id = api_config.get("prefix_id", None)
  1254. if prefix_id:
  1255. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  1256. return await send_post_request(
  1257. url=f"{url}/v1/chat/completions",
  1258. payload=json.dumps(payload),
  1259. stream=payload.get("stream", False),
  1260. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1261. user=user,
  1262. )
  1263. @router.get("/v1/models")
  1264. @router.get("/v1/models/{url_idx}")
  1265. async def get_openai_models(
  1266. request: Request,
  1267. url_idx: Optional[int] = None,
  1268. user=Depends(get_verified_user),
  1269. ):
  1270. models = []
  1271. if url_idx is None:
  1272. model_list = await get_all_models(request, user=user)
  1273. models = [
  1274. {
  1275. "id": model["model"],
  1276. "object": "model",
  1277. "created": int(time.time()),
  1278. "owned_by": "openai",
  1279. }
  1280. for model in model_list["models"]
  1281. ]
  1282. else:
  1283. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1284. try:
  1285. r = requests.request(method="GET", url=f"{url}/api/tags")
  1286. r.raise_for_status()
  1287. model_list = r.json()
  1288. models = [
  1289. {
  1290. "id": model["model"],
  1291. "object": "model",
  1292. "created": int(time.time()),
  1293. "owned_by": "openai",
  1294. }
  1295. for model in models["models"]
  1296. ]
  1297. except Exception as e:
  1298. log.exception(e)
  1299. error_detail = "Open WebUI: Server Connection Error"
  1300. if r is not None:
  1301. try:
  1302. res = r.json()
  1303. if "error" in res:
  1304. error_detail = f"Ollama: {res['error']}"
  1305. except Exception:
  1306. error_detail = f"Ollama: {e}"
  1307. raise HTTPException(
  1308. status_code=r.status_code if r else 500,
  1309. detail=error_detail,
  1310. )
  1311. if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
  1312. # Filter models based on user access control
  1313. filtered_models = []
  1314. for model in models:
  1315. model_info = Models.get_model_by_id(model["id"])
  1316. if model_info:
  1317. if user.id == model_info.user_id or has_access(
  1318. user.id, type="read", access_control=model_info.access_control
  1319. ):
  1320. filtered_models.append(model)
  1321. models = filtered_models
  1322. return {
  1323. "data": models,
  1324. "object": "list",
  1325. }
class UrlForm(BaseModel):
    # Request body of `download_model`: the URL of the file to download.
    url: str
class UploadBlobForm(BaseModel):
    # Carries a single file name. NOTE(review): no use of this form is
    # visible in this part of the file - confirm it is still needed.
    filename: str
  1330. def parse_huggingface_url(hf_url):
  1331. try:
  1332. # Parse the URL
  1333. parsed_url = urlparse(hf_url)
  1334. # Get the path and split it into components
  1335. path_components = parsed_url.path.split("/")
  1336. # Extract the desired output
  1337. model_file = path_components[-1]
  1338. return model_file
  1339. except ValueError:
  1340. return None
  1341. async def download_file_stream(
  1342. ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
  1343. ):
  1344. done = False
  1345. if os.path.exists(file_path):
  1346. current_size = os.path.getsize(file_path)
  1347. else:
  1348. current_size = 0
  1349. headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
  1350. timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
  1351. async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
  1352. async with session.get(
  1353. file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL
  1354. ) as response:
  1355. total_size = int(response.headers.get("content-length", 0)) + current_size
  1356. with open(file_path, "ab+") as file:
  1357. async for data in response.content.iter_chunked(chunk_size):
  1358. current_size += len(data)
  1359. file.write(data)
  1360. done = current_size == total_size
  1361. progress = round((current_size / total_size) * 100, 2)
  1362. yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
  1363. if done:
  1364. file.seek(0)
  1365. chunk_size = 1024 * 1024 * 2
  1366. hashed = calculate_sha256(file, chunk_size)
  1367. file.seek(0)
  1368. url = f"{ollama_url}/api/blobs/sha256:{hashed}"
  1369. response = requests.post(url, data=file)
  1370. if response.ok:
  1371. res = {
  1372. "done": done,
  1373. "blob": f"sha256:{hashed}",
  1374. "name": file_name,
  1375. }
  1376. os.remove(file_path)
  1377. yield f"data: {json.dumps(res)}\n\n"
  1378. else:
  1379. raise "Ollama: Could not create blob, Please try again."
  1380. # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
  1381. @router.post("/models/download")
  1382. @router.post("/models/download/{url_idx}")
  1383. async def download_model(
  1384. request: Request,
  1385. form_data: UrlForm,
  1386. url_idx: Optional[int] = None,
  1387. user=Depends(get_admin_user),
  1388. ):
  1389. allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
  1390. if not any(form_data.url.startswith(host) for host in allowed_hosts):
  1391. raise HTTPException(
  1392. status_code=400,
  1393. detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
  1394. )
  1395. if url_idx is None:
  1396. url_idx = 0
  1397. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1398. file_name = parse_huggingface_url(form_data.url)
  1399. if file_name:
  1400. file_path = f"{UPLOAD_DIR}/{file_name}"
  1401. return StreamingResponse(
  1402. download_file_stream(url, form_data.url, file_path, file_name),
  1403. )
  1404. else:
  1405. return None
  1406. # TODO: Progress bar does not reflect size & duration of upload.
  1407. @router.post("/models/upload")
  1408. @router.post("/models/upload/{url_idx}")
  1409. async def upload_model(
  1410. request: Request,
  1411. file: UploadFile = File(...),
  1412. url_idx: Optional[int] = None,
  1413. user=Depends(get_admin_user),
  1414. ):
  1415. if url_idx is None:
  1416. url_idx = 0
  1417. ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1418. filename = os.path.basename(file.filename)
  1419. file_path = os.path.join(UPLOAD_DIR, filename)
  1420. os.makedirs(UPLOAD_DIR, exist_ok=True)
  1421. # --- P1: save file locally ---
  1422. chunk_size = 1024 * 1024 * 2 # 2 MB chunks
  1423. with open(file_path, "wb") as out_f:
  1424. while True:
  1425. chunk = file.file.read(chunk_size)
  1426. # log.info(f"Chunk: {str(chunk)}") # DEBUG
  1427. if not chunk:
  1428. break
  1429. out_f.write(chunk)
  1430. async def file_process_stream():
  1431. nonlocal ollama_url
  1432. total_size = os.path.getsize(file_path)
  1433. log.info(f"Total Model Size: {str(total_size)}") # DEBUG
  1434. # --- P2: SSE progress + calculate sha256 hash ---
  1435. file_hash = calculate_sha256(file_path, chunk_size)
  1436. log.info(f"Model Hash: {str(file_hash)}") # DEBUG
  1437. try:
  1438. with open(file_path, "rb") as f:
  1439. bytes_read = 0
  1440. while chunk := f.read(chunk_size):
  1441. bytes_read += len(chunk)
  1442. progress = round(bytes_read / total_size * 100, 2)
  1443. data_msg = {
  1444. "progress": progress,
  1445. "total": total_size,
  1446. "completed": bytes_read,
  1447. }
  1448. yield f"data: {json.dumps(data_msg)}\n\n"
  1449. # --- P3: Upload to ollama /api/blobs ---
  1450. with open(file_path, "rb") as f:
  1451. url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
  1452. response = requests.post(url, data=f)
  1453. if response.ok:
  1454. log.info(f"Uploaded to /api/blobs") # DEBUG
  1455. # Remove local file
  1456. os.remove(file_path)
  1457. # Create model in ollama
  1458. model_name, ext = os.path.splitext(filename)
  1459. log.info(f"Created Model: {model_name}") # DEBUG
  1460. create_payload = {
  1461. "model": model_name,
  1462. # Reference the file by its original name => the uploaded blob's digest
  1463. "files": {filename: f"sha256:{file_hash}"},
  1464. }
  1465. log.info(f"Model Payload: {create_payload}") # DEBUG
  1466. # Call ollama /api/create
  1467. # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
  1468. create_resp = requests.post(
  1469. url=f"{ollama_url}/api/create",
  1470. headers={"Content-Type": "application/json"},
  1471. data=json.dumps(create_payload),
  1472. )
  1473. if create_resp.ok:
  1474. log.info(f"API SUCCESS!") # DEBUG
  1475. done_msg = {
  1476. "done": True,
  1477. "blob": f"sha256:{file_hash}",
  1478. "name": filename,
  1479. "model_created": model_name,
  1480. }
  1481. yield f"data: {json.dumps(done_msg)}\n\n"
  1482. else:
  1483. raise Exception(
  1484. f"Failed to create model in Ollama. {create_resp.text}"
  1485. )
  1486. else:
  1487. raise Exception("Ollama: Could not create blob, Please try again.")
  1488. except Exception as e:
  1489. res = {"error": str(e)}
  1490. yield f"data: {json.dumps(res)}\n\n"
  1491. return StreamingResponse(file_process_stream(), media_type="text/event-stream")