ollama.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678
# TODO: Implement a smarter load-balancing mechanism for distributing requests among multiple backend instances.
# The current implementation picks a backend uniformly at random (random.choice), which is neither true
# round-robin nor load-aware. Consider algorithms such as weighted round-robin, least connections, or
# least response time for better resource utilization and performance.
  4. import asyncio
  5. import json
  6. import logging
  7. import os
  8. import random
  9. import re
  10. import time
  11. from typing import Optional, Union
  12. from urllib.parse import urlparse
  13. import aiohttp
  14. from aiocache import cached
  15. import requests
  16. from open_webui.models.users import UserModel
  17. from open_webui.env import (
  18. ENABLE_FORWARD_USER_INFO_HEADERS,
  19. )
  20. from fastapi import (
  21. Depends,
  22. FastAPI,
  23. File,
  24. HTTPException,
  25. Request,
  26. UploadFile,
  27. APIRouter,
  28. )
  29. from fastapi.middleware.cors import CORSMiddleware
  30. from fastapi.responses import StreamingResponse
  31. from pydantic import BaseModel, ConfigDict, validator
  32. from starlette.background import BackgroundTask
  33. from open_webui.models.models import Models
  34. from open_webui.utils.misc import (
  35. calculate_sha256,
  36. )
  37. from open_webui.utils.payload import (
  38. apply_model_params_to_body_ollama,
  39. apply_model_params_to_body_openai,
  40. apply_model_system_prompt_to_body,
  41. )
  42. from open_webui.utils.auth import get_admin_user, get_verified_user
  43. from open_webui.utils.access_control import has_access
  44. from open_webui.config import (
  45. UPLOAD_DIR,
  46. )
  47. from open_webui.env import (
  48. ENV,
  49. SRC_LOG_LEVELS,
  50. AIOHTTP_CLIENT_SESSION_SSL,
  51. AIOHTTP_CLIENT_TIMEOUT,
  52. AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
  53. BYPASS_MODEL_ACCESS_CONTROL,
  54. )
  55. from open_webui.constants import ERROR_MESSAGES
  56. log = logging.getLogger(__name__)
  57. log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
  58. ##########################################
  59. #
  60. # Utility functions
  61. #
  62. ##########################################
  63. async def send_get_request(url, key=None, user: UserModel = None):
  64. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
  65. try:
  66. async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
  67. async with session.get(
  68. url,
  69. headers={
  70. "Content-Type": "application/json",
  71. **({"Authorization": f"Bearer {key}"} if key else {}),
  72. **(
  73. {
  74. "X-OpenWebUI-User-Name": user.name,
  75. "X-OpenWebUI-User-Id": user.id,
  76. "X-OpenWebUI-User-Email": user.email,
  77. "X-OpenWebUI-User-Role": user.role,
  78. }
  79. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  80. else {}
  81. ),
  82. },
  83. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  84. ) as response:
  85. return await response.json()
  86. except Exception as e:
  87. # Handle connection error here
  88. log.error(f"Connection error: {e}")
  89. return None
  90. async def cleanup_response(
  91. response: Optional[aiohttp.ClientResponse],
  92. session: Optional[aiohttp.ClientSession],
  93. ):
  94. if response:
  95. response.close()
  96. if session:
  97. await session.close()
  98. async def send_post_request(
  99. url: str,
  100. payload: Union[str, bytes],
  101. stream: bool = True,
  102. key: Optional[str] = None,
  103. content_type: Optional[str] = None,
  104. user: UserModel = None,
  105. ):
  106. r = None
  107. try:
  108. session = aiohttp.ClientSession(
  109. trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  110. )
  111. r = await session.post(
  112. url,
  113. data=payload,
  114. headers={
  115. "Content-Type": "application/json",
  116. **({"Authorization": f"Bearer {key}"} if key else {}),
  117. **(
  118. {
  119. "X-OpenWebUI-User-Name": user.name,
  120. "X-OpenWebUI-User-Id": user.id,
  121. "X-OpenWebUI-User-Email": user.email,
  122. "X-OpenWebUI-User-Role": user.role,
  123. }
  124. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  125. else {}
  126. ),
  127. },
  128. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  129. )
  130. r.raise_for_status()
  131. if stream:
  132. response_headers = dict(r.headers)
  133. if content_type:
  134. response_headers["Content-Type"] = content_type
  135. return StreamingResponse(
  136. r.content,
  137. status_code=r.status,
  138. headers=response_headers,
  139. background=BackgroundTask(
  140. cleanup_response, response=r, session=session
  141. ),
  142. )
  143. else:
  144. res = await r.json()
  145. await cleanup_response(r, session)
  146. return res
  147. except Exception as e:
  148. detail = None
  149. if r is not None:
  150. try:
  151. res = await r.json()
  152. if "error" in res:
  153. detail = f"Ollama: {res.get('error', 'Unknown error')}"
  154. except Exception:
  155. detail = f"Ollama: {e}"
  156. raise HTTPException(
  157. status_code=r.status if r else 500,
  158. detail=detail if detail else "Open WebUI: Server Connection Error",
  159. )
  160. def get_api_key(idx, url, configs):
  161. parsed_url = urlparse(url)
  162. base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
  163. return configs.get(str(idx), configs.get(base_url, {})).get(
  164. "key", None
  165. ) # Legacy support
  166. ##########################################
  167. #
  168. # API routes
  169. #
  170. ##########################################
  171. router = APIRouter()
  172. @router.head("/")
  173. @router.get("/")
  174. async def get_status():
  175. return {"status": True}
class ConnectionVerificationForm(BaseModel):
    """Request body for ``POST /verify``."""

    url: str  # base URL of the server to probe; "/api/version" is appended
    key: Optional[str] = None  # optional bearer token for that server
  179. @router.post("/verify")
  180. async def verify_connection(
  181. form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
  182. ):
  183. url = form_data.url
  184. key = form_data.key
  185. async with aiohttp.ClientSession(
  186. trust_env=True,
  187. timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
  188. ) as session:
  189. try:
  190. async with session.get(
  191. f"{url}/api/version",
  192. headers={
  193. **({"Authorization": f"Bearer {key}"} if key else {}),
  194. **(
  195. {
  196. "X-OpenWebUI-User-Name": user.name,
  197. "X-OpenWebUI-User-Id": user.id,
  198. "X-OpenWebUI-User-Email": user.email,
  199. "X-OpenWebUI-User-Role": user.role,
  200. }
  201. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  202. else {}
  203. ),
  204. },
  205. ssl=AIOHTTP_CLIENT_SESSION_SSL,
  206. ) as r:
  207. if r.status != 200:
  208. detail = f"HTTP Error: {r.status}"
  209. res = await r.json()
  210. if "error" in res:
  211. detail = f"External Error: {res['error']}"
  212. raise Exception(detail)
  213. data = await r.json()
  214. return data
  215. except aiohttp.ClientError as e:
  216. log.exception(f"Client error: {str(e)}")
  217. raise HTTPException(
  218. status_code=500, detail="Open WebUI: Server Connection Error"
  219. )
  220. except Exception as e:
  221. log.exception(f"Unexpected error: {e}")
  222. error_detail = f"Unexpected error: {str(e)}"
  223. raise HTTPException(status_code=500, detail=error_detail)
  224. @router.get("/config")
  225. async def get_config(request: Request, user=Depends(get_admin_user)):
  226. return {
  227. "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
  228. "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
  229. "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
  230. }
class OllamaConfigForm(BaseModel):
    """Request body for ``POST /config/update``."""

    # Written to the config as-is by update_config; omitting it stores None.
    ENABLE_OLLAMA_API: Optional[bool] = None
    OLLAMA_BASE_URLS: list[str]
    # Per-connection settings; update_config keeps only the keys "0".."n-1"
    # that match an index into OLLAMA_BASE_URLS.
    OLLAMA_API_CONFIGS: dict
  235. @router.post("/config/update")
  236. async def update_config(
  237. request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
  238. ):
  239. request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
  240. request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
  241. request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
  242. # Remove the API configs that are not in the API URLS
  243. keys = list(map(str, range(len(request.app.state.config.OLLAMA_BASE_URLS))))
  244. request.app.state.config.OLLAMA_API_CONFIGS = {
  245. key: value
  246. for key, value in request.app.state.config.OLLAMA_API_CONFIGS.items()
  247. if key in keys
  248. }
  249. return {
  250. "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
  251. "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
  252. "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
  253. }
  254. @cached(ttl=1)
  255. async def get_all_models(request: Request, user: UserModel = None):
  256. log.info("get_all_models()")
  257. if request.app.state.config.ENABLE_OLLAMA_API:
  258. request_tasks = []
  259. for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
  260. if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
  261. url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
  262. ):
  263. request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
  264. else:
  265. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  266. str(idx),
  267. request.app.state.config.OLLAMA_API_CONFIGS.get(
  268. url, {}
  269. ), # Legacy support
  270. )
  271. enable = api_config.get("enable", True)
  272. key = api_config.get("key", None)
  273. if enable:
  274. request_tasks.append(
  275. send_get_request(f"{url}/api/tags", key, user=user)
  276. )
  277. else:
  278. request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
  279. responses = await asyncio.gather(*request_tasks)
  280. for idx, response in enumerate(responses):
  281. if response:
  282. url = request.app.state.config.OLLAMA_BASE_URLS[idx]
  283. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  284. str(idx),
  285. request.app.state.config.OLLAMA_API_CONFIGS.get(
  286. url, {}
  287. ), # Legacy support
  288. )
  289. connection_type = api_config.get("connection_type", "local")
  290. prefix_id = api_config.get("prefix_id", None)
  291. tags = api_config.get("tags", [])
  292. model_ids = api_config.get("model_ids", [])
  293. if len(model_ids) != 0 and "models" in response:
  294. response["models"] = list(
  295. filter(
  296. lambda model: model["model"] in model_ids,
  297. response["models"],
  298. )
  299. )
  300. for model in response.get("models", []):
  301. if prefix_id:
  302. model["model"] = f"{prefix_id}.{model['model']}"
  303. if tags:
  304. model["tags"] = tags
  305. if connection_type:
  306. model["connection_type"] = connection_type
  307. def merge_models_lists(model_lists):
  308. merged_models = {}
  309. for idx, model_list in enumerate(model_lists):
  310. if model_list is not None:
  311. for model in model_list:
  312. id = model["model"]
  313. if id not in merged_models:
  314. model["urls"] = [idx]
  315. merged_models[id] = model
  316. else:
  317. merged_models[id]["urls"].append(idx)
  318. return list(merged_models.values())
  319. models = {
  320. "models": merge_models_lists(
  321. map(
  322. lambda response: response.get("models", []) if response else None,
  323. responses,
  324. )
  325. )
  326. }
  327. else:
  328. models = {"models": []}
  329. request.app.state.OLLAMA_MODELS = {
  330. model["model"]: model for model in models["models"]
  331. }
  332. return models
  333. async def get_filtered_models(models, user):
  334. # Filter models based on user access control
  335. filtered_models = []
  336. for model in models.get("models", []):
  337. model_info = Models.get_model_by_id(model["model"])
  338. if model_info:
  339. if user.id == model_info.user_id or has_access(
  340. user.id, type="read", access_control=model_info.access_control
  341. ):
  342. filtered_models.append(model)
  343. return filtered_models
  344. @router.get("/api/tags")
  345. @router.get("/api/tags/{url_idx}")
  346. async def get_ollama_tags(
  347. request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
  348. ):
  349. models = []
  350. if url_idx is None:
  351. models = await get_all_models(request, user=user)
  352. else:
  353. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  354. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  355. r = None
  356. try:
  357. r = requests.request(
  358. method="GET",
  359. url=f"{url}/api/tags",
  360. headers={
  361. **({"Authorization": f"Bearer {key}"} if key else {}),
  362. **(
  363. {
  364. "X-OpenWebUI-User-Name": user.name,
  365. "X-OpenWebUI-User-Id": user.id,
  366. "X-OpenWebUI-User-Email": user.email,
  367. "X-OpenWebUI-User-Role": user.role,
  368. }
  369. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  370. else {}
  371. ),
  372. },
  373. )
  374. r.raise_for_status()
  375. models = r.json()
  376. except Exception as e:
  377. log.exception(e)
  378. detail = None
  379. if r is not None:
  380. try:
  381. res = r.json()
  382. if "error" in res:
  383. detail = f"Ollama: {res['error']}"
  384. except Exception:
  385. detail = f"Ollama: {e}"
  386. raise HTTPException(
  387. status_code=r.status_code if r else 500,
  388. detail=detail if detail else "Open WebUI: Server Connection Error",
  389. )
  390. if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
  391. models["models"] = await get_filtered_models(models, user)
  392. return models
  393. @router.get("/api/version")
  394. @router.get("/api/version/{url_idx}")
  395. async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
  396. if request.app.state.config.ENABLE_OLLAMA_API:
  397. if url_idx is None:
  398. # returns lowest version
  399. request_tasks = []
  400. for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
  401. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  402. str(idx),
  403. request.app.state.config.OLLAMA_API_CONFIGS.get(
  404. url, {}
  405. ), # Legacy support
  406. )
  407. enable = api_config.get("enable", True)
  408. key = api_config.get("key", None)
  409. if enable:
  410. request_tasks.append(
  411. send_get_request(
  412. f"{url}/api/version",
  413. key,
  414. )
  415. )
  416. responses = await asyncio.gather(*request_tasks)
  417. responses = list(filter(lambda x: x is not None, responses))
  418. if len(responses) > 0:
  419. lowest_version = min(
  420. responses,
  421. key=lambda x: tuple(
  422. map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
  423. ),
  424. )
  425. return {"version": lowest_version["version"]}
  426. else:
  427. raise HTTPException(
  428. status_code=500,
  429. detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
  430. )
  431. else:
  432. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  433. r = None
  434. try:
  435. r = requests.request(method="GET", url=f"{url}/api/version")
  436. r.raise_for_status()
  437. return r.json()
  438. except Exception as e:
  439. log.exception(e)
  440. detail = None
  441. if r is not None:
  442. try:
  443. res = r.json()
  444. if "error" in res:
  445. detail = f"Ollama: {res['error']}"
  446. except Exception:
  447. detail = f"Ollama: {e}"
  448. raise HTTPException(
  449. status_code=r.status_code if r else 500,
  450. detail=detail if detail else "Open WebUI: Server Connection Error",
  451. )
  452. else:
  453. return {"version": False}
  454. @router.get("/api/ps")
  455. async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
  456. """
  457. List models that are currently loaded into Ollama memory, and which node they are loaded on.
  458. """
  459. if request.app.state.config.ENABLE_OLLAMA_API:
  460. request_tasks = [
  461. send_get_request(
  462. f"{url}/api/ps",
  463. request.app.state.config.OLLAMA_API_CONFIGS.get(
  464. str(idx),
  465. request.app.state.config.OLLAMA_API_CONFIGS.get(
  466. url, {}
  467. ), # Legacy support
  468. ).get("key", None),
  469. user=user,
  470. )
  471. for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
  472. ]
  473. responses = await asyncio.gather(*request_tasks)
  474. return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
  475. else:
  476. return {}
class ModelNameForm(BaseModel):
    """Request body carrying a single model name (used by pull, delete and show)."""

    name: str
  479. @router.post("/api/pull")
  480. @router.post("/api/pull/{url_idx}")
  481. async def pull_model(
  482. request: Request,
  483. form_data: ModelNameForm,
  484. url_idx: int = 0,
  485. user=Depends(get_admin_user),
  486. ):
  487. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  488. log.info(f"url: {url}")
  489. # Admin should be able to pull models from any source
  490. payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
  491. return await send_post_request(
  492. url=f"{url}/api/pull",
  493. payload=json.dumps(payload),
  494. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  495. user=user,
  496. )
class PushModelForm(BaseModel):
    """Request body for the push endpoint.

    Serialized with None fields omitted and sent to the backend's ``/api/push``.
    """

    name: str
    insecure: Optional[bool] = None
    stream: Optional[bool] = None
  501. @router.delete("/api/push")
  502. @router.delete("/api/push/{url_idx}")
  503. async def push_model(
  504. request: Request,
  505. form_data: PushModelForm,
  506. url_idx: Optional[int] = None,
  507. user=Depends(get_admin_user),
  508. ):
  509. if url_idx is None:
  510. await get_all_models(request, user=user)
  511. models = request.app.state.OLLAMA_MODELS
  512. if form_data.name in models:
  513. url_idx = models[form_data.name]["urls"][0]
  514. else:
  515. raise HTTPException(
  516. status_code=400,
  517. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
  518. )
  519. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  520. log.debug(f"url: {url}")
  521. return await send_post_request(
  522. url=f"{url}/api/push",
  523. payload=form_data.model_dump_json(exclude_none=True).encode(),
  524. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  525. user=user,
  526. )
class CreateModelForm(BaseModel):
    """Request body for ``POST /api/create``.

    Serialized with None fields omitted and sent to the backend's ``/api/create``.
    """

    model: Optional[str] = None
    stream: Optional[bool] = None
    path: Optional[str] = None
    # Keep unknown fields so they are included when the form is serialized
    # and forwarded upstream.
    model_config = ConfigDict(extra="allow")
  532. @router.post("/api/create")
  533. @router.post("/api/create/{url_idx}")
  534. async def create_model(
  535. request: Request,
  536. form_data: CreateModelForm,
  537. url_idx: int = 0,
  538. user=Depends(get_admin_user),
  539. ):
  540. log.debug(f"form_data: {form_data}")
  541. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  542. return await send_post_request(
  543. url=f"{url}/api/create",
  544. payload=form_data.model_dump_json(exclude_none=True).encode(),
  545. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  546. user=user,
  547. )
class CopyModelForm(BaseModel):
    """Request body for ``POST /api/copy``, forwarded as JSON to the backend's ``/api/copy``.

    ``source`` is also used to locate a hosting backend when no index is given.
    """

    source: str
    destination: str
  551. @router.post("/api/copy")
  552. @router.post("/api/copy/{url_idx}")
  553. async def copy_model(
  554. request: Request,
  555. form_data: CopyModelForm,
  556. url_idx: Optional[int] = None,
  557. user=Depends(get_admin_user),
  558. ):
  559. if url_idx is None:
  560. await get_all_models(request, user=user)
  561. models = request.app.state.OLLAMA_MODELS
  562. if form_data.source in models:
  563. url_idx = models[form_data.source]["urls"][0]
  564. else:
  565. raise HTTPException(
  566. status_code=400,
  567. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
  568. )
  569. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  570. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  571. try:
  572. r = requests.request(
  573. method="POST",
  574. url=f"{url}/api/copy",
  575. headers={
  576. "Content-Type": "application/json",
  577. **({"Authorization": f"Bearer {key}"} if key else {}),
  578. **(
  579. {
  580. "X-OpenWebUI-User-Name": user.name,
  581. "X-OpenWebUI-User-Id": user.id,
  582. "X-OpenWebUI-User-Email": user.email,
  583. "X-OpenWebUI-User-Role": user.role,
  584. }
  585. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  586. else {}
  587. ),
  588. },
  589. data=form_data.model_dump_json(exclude_none=True).encode(),
  590. )
  591. r.raise_for_status()
  592. log.debug(f"r.text: {r.text}")
  593. return True
  594. except Exception as e:
  595. log.exception(e)
  596. detail = None
  597. if r is not None:
  598. try:
  599. res = r.json()
  600. if "error" in res:
  601. detail = f"Ollama: {res['error']}"
  602. except Exception:
  603. detail = f"Ollama: {e}"
  604. raise HTTPException(
  605. status_code=r.status_code if r else 500,
  606. detail=detail if detail else "Open WebUI: Server Connection Error",
  607. )
  608. @router.delete("/api/delete")
  609. @router.delete("/api/delete/{url_idx}")
  610. async def delete_model(
  611. request: Request,
  612. form_data: ModelNameForm,
  613. url_idx: Optional[int] = None,
  614. user=Depends(get_admin_user),
  615. ):
  616. if url_idx is None:
  617. await get_all_models(request, user=user)
  618. models = request.app.state.OLLAMA_MODELS
  619. if form_data.name in models:
  620. url_idx = models[form_data.name]["urls"][0]
  621. else:
  622. raise HTTPException(
  623. status_code=400,
  624. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
  625. )
  626. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  627. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  628. try:
  629. r = requests.request(
  630. method="DELETE",
  631. url=f"{url}/api/delete",
  632. data=form_data.model_dump_json(exclude_none=True).encode(),
  633. headers={
  634. "Content-Type": "application/json",
  635. **({"Authorization": f"Bearer {key}"} if key else {}),
  636. **(
  637. {
  638. "X-OpenWebUI-User-Name": user.name,
  639. "X-OpenWebUI-User-Id": user.id,
  640. "X-OpenWebUI-User-Email": user.email,
  641. "X-OpenWebUI-User-Role": user.role,
  642. }
  643. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  644. else {}
  645. ),
  646. },
  647. )
  648. r.raise_for_status()
  649. log.debug(f"r.text: {r.text}")
  650. return True
  651. except Exception as e:
  652. log.exception(e)
  653. detail = None
  654. if r is not None:
  655. try:
  656. res = r.json()
  657. if "error" in res:
  658. detail = f"Ollama: {res['error']}"
  659. except Exception:
  660. detail = f"Ollama: {e}"
  661. raise HTTPException(
  662. status_code=r.status_code if r else 500,
  663. detail=detail if detail else "Open WebUI: Server Connection Error",
  664. )
  665. @router.post("/api/show")
  666. async def show_model_info(
  667. request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
  668. ):
  669. await get_all_models(request, user=user)
  670. models = request.app.state.OLLAMA_MODELS
  671. if form_data.name not in models:
  672. raise HTTPException(
  673. status_code=400,
  674. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
  675. )
  676. url_idx = random.choice(models[form_data.name]["urls"])
  677. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  678. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  679. try:
  680. r = requests.request(
  681. method="POST",
  682. url=f"{url}/api/show",
  683. headers={
  684. "Content-Type": "application/json",
  685. **({"Authorization": f"Bearer {key}"} if key else {}),
  686. **(
  687. {
  688. "X-OpenWebUI-User-Name": user.name,
  689. "X-OpenWebUI-User-Id": user.id,
  690. "X-OpenWebUI-User-Email": user.email,
  691. "X-OpenWebUI-User-Role": user.role,
  692. }
  693. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  694. else {}
  695. ),
  696. },
  697. data=form_data.model_dump_json(exclude_none=True).encode(),
  698. )
  699. r.raise_for_status()
  700. return r.json()
  701. except Exception as e:
  702. log.exception(e)
  703. detail = None
  704. if r is not None:
  705. try:
  706. res = r.json()
  707. if "error" in res:
  708. detail = f"Ollama: {res['error']}"
  709. except Exception:
  710. detail = f"Ollama: {e}"
  711. raise HTTPException(
  712. status_code=r.status_code if r else 500,
  713. detail=detail if detail else "Open WebUI: Server Connection Error",
  714. )
class GenerateEmbedForm(BaseModel):
    """Request body for ``POST /api/embed``.

    Serialized with None fields omitted and sent to the backend's ``/api/embed``.
    """

    model: str  # looked up as "<model>:latest" when it contains no ":"
    input: list[str] | str  # a single string or a list of strings
    truncate: Optional[bool] = None
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None
  721. @router.post("/api/embed")
  722. @router.post("/api/embed/{url_idx}")
  723. async def embed(
  724. request: Request,
  725. form_data: GenerateEmbedForm,
  726. url_idx: Optional[int] = None,
  727. user=Depends(get_verified_user),
  728. ):
  729. log.info(f"generate_ollama_batch_embeddings {form_data}")
  730. if url_idx is None:
  731. await get_all_models(request, user=user)
  732. models = request.app.state.OLLAMA_MODELS
  733. model = form_data.model
  734. if ":" not in model:
  735. model = f"{model}:latest"
  736. if model in models:
  737. url_idx = random.choice(models[model]["urls"])
  738. else:
  739. raise HTTPException(
  740. status_code=400,
  741. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  742. )
  743. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  744. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  745. str(url_idx),
  746. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  747. )
  748. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  749. prefix_id = api_config.get("prefix_id", None)
  750. if prefix_id:
  751. form_data.model = form_data.model.replace(f"{prefix_id}.", "")
  752. try:
  753. r = requests.request(
  754. method="POST",
  755. url=f"{url}/api/embed",
  756. headers={
  757. "Content-Type": "application/json",
  758. **({"Authorization": f"Bearer {key}"} if key else {}),
  759. **(
  760. {
  761. "X-OpenWebUI-User-Name": user.name,
  762. "X-OpenWebUI-User-Id": user.id,
  763. "X-OpenWebUI-User-Email": user.email,
  764. "X-OpenWebUI-User-Role": user.role,
  765. }
  766. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  767. else {}
  768. ),
  769. },
  770. data=form_data.model_dump_json(exclude_none=True).encode(),
  771. )
  772. r.raise_for_status()
  773. data = r.json()
  774. return data
  775. except Exception as e:
  776. log.exception(e)
  777. detail = None
  778. if r is not None:
  779. try:
  780. res = r.json()
  781. if "error" in res:
  782. detail = f"Ollama: {res['error']}"
  783. except Exception:
  784. detail = f"Ollama: {e}"
  785. raise HTTPException(
  786. status_code=r.status_code if r else 500,
  787. detail=detail if detail else "Open WebUI: Server Connection Error",
  788. )
  789. class GenerateEmbeddingsForm(BaseModel):
  790. model: str
  791. prompt: str
  792. options: Optional[dict] = None
  793. keep_alive: Optional[Union[int, str]] = None
  794. @router.post("/api/embeddings")
  795. @router.post("/api/embeddings/{url_idx}")
  796. async def embeddings(
  797. request: Request,
  798. form_data: GenerateEmbeddingsForm,
  799. url_idx: Optional[int] = None,
  800. user=Depends(get_verified_user),
  801. ):
  802. log.info(f"generate_ollama_embeddings {form_data}")
  803. if url_idx is None:
  804. await get_all_models(request, user=user)
  805. models = request.app.state.OLLAMA_MODELS
  806. model = form_data.model
  807. if ":" not in model:
  808. model = f"{model}:latest"
  809. if model in models:
  810. url_idx = random.choice(models[model]["urls"])
  811. else:
  812. raise HTTPException(
  813. status_code=400,
  814. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  815. )
  816. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  817. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  818. str(url_idx),
  819. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  820. )
  821. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  822. prefix_id = api_config.get("prefix_id", None)
  823. if prefix_id:
  824. form_data.model = form_data.model.replace(f"{prefix_id}.", "")
  825. try:
  826. r = requests.request(
  827. method="POST",
  828. url=f"{url}/api/embeddings",
  829. headers={
  830. "Content-Type": "application/json",
  831. **({"Authorization": f"Bearer {key}"} if key else {}),
  832. **(
  833. {
  834. "X-OpenWebUI-User-Name": user.name,
  835. "X-OpenWebUI-User-Id": user.id,
  836. "X-OpenWebUI-User-Email": user.email,
  837. "X-OpenWebUI-User-Role": user.role,
  838. }
  839. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  840. else {}
  841. ),
  842. },
  843. data=form_data.model_dump_json(exclude_none=True).encode(),
  844. )
  845. r.raise_for_status()
  846. data = r.json()
  847. return data
  848. except Exception as e:
  849. log.exception(e)
  850. detail = None
  851. if r is not None:
  852. try:
  853. res = r.json()
  854. if "error" in res:
  855. detail = f"Ollama: {res['error']}"
  856. except Exception:
  857. detail = f"Ollama: {e}"
  858. raise HTTPException(
  859. status_code=r.status_code if r else 500,
  860. detail=detail if detail else "Open WebUI: Server Connection Error",
  861. )
  862. class GenerateCompletionForm(BaseModel):
  863. model: str
  864. prompt: str
  865. suffix: Optional[str] = None
  866. images: Optional[list[str]] = None
  867. format: Optional[Union[dict, str]] = None
  868. options: Optional[dict] = None
  869. system: Optional[str] = None
  870. template: Optional[str] = None
  871. context: Optional[list[int]] = None
  872. stream: Optional[bool] = True
  873. raw: Optional[bool] = None
  874. keep_alive: Optional[Union[int, str]] = None
  875. @router.post("/api/generate")
  876. @router.post("/api/generate/{url_idx}")
  877. async def generate_completion(
  878. request: Request,
  879. form_data: GenerateCompletionForm,
  880. url_idx: Optional[int] = None,
  881. user=Depends(get_verified_user),
  882. ):
  883. if url_idx is None:
  884. await get_all_models(request, user=user)
  885. models = request.app.state.OLLAMA_MODELS
  886. model = form_data.model
  887. if ":" not in model:
  888. model = f"{model}:latest"
  889. if model in models:
  890. url_idx = random.choice(models[model]["urls"])
  891. else:
  892. raise HTTPException(
  893. status_code=400,
  894. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  895. )
  896. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  897. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  898. str(url_idx),
  899. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  900. )
  901. prefix_id = api_config.get("prefix_id", None)
  902. if prefix_id:
  903. form_data.model = form_data.model.replace(f"{prefix_id}.", "")
  904. return await send_post_request(
  905. url=f"{url}/api/generate",
  906. payload=form_data.model_dump_json(exclude_none=True).encode(),
  907. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  908. user=user,
  909. )
  910. class ChatMessage(BaseModel):
  911. role: str
  912. content: Optional[str] = None
  913. tool_calls: Optional[list[dict]] = None
  914. images: Optional[list[str]] = None
  915. @validator("content", pre=True)
  916. @classmethod
  917. def check_at_least_one_field(cls, field_value, values, **kwargs):
  918. # Raise an error if both 'content' and 'tool_calls' are None
  919. if field_value is None and (
  920. "tool_calls" not in values or values["tool_calls"] is None
  921. ):
  922. raise ValueError(
  923. "At least one of 'content' or 'tool_calls' must be provided"
  924. )
  925. return field_value
  926. class GenerateChatCompletionForm(BaseModel):
  927. model: str
  928. messages: list[ChatMessage]
  929. format: Optional[Union[dict, str]] = None
  930. options: Optional[dict] = None
  931. template: Optional[str] = None
  932. stream: Optional[bool] = True
  933. keep_alive: Optional[Union[int, str]] = None
  934. tools: Optional[list[dict]] = None
  935. async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
  936. if url_idx is None:
  937. models = request.app.state.OLLAMA_MODELS
  938. if model not in models:
  939. raise HTTPException(
  940. status_code=400,
  941. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
  942. )
  943. url_idx = random.choice(models[model].get("urls", []))
  944. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  945. return url, url_idx
  946. @router.post("/api/chat")
  947. @router.post("/api/chat/{url_idx}")
  948. async def generate_chat_completion(
  949. request: Request,
  950. form_data: dict,
  951. url_idx: Optional[int] = None,
  952. user=Depends(get_verified_user),
  953. bypass_filter: Optional[bool] = False,
  954. ):
  955. if BYPASS_MODEL_ACCESS_CONTROL:
  956. bypass_filter = True
  957. metadata = form_data.pop("metadata", None)
  958. try:
  959. form_data = GenerateChatCompletionForm(**form_data)
  960. except Exception as e:
  961. log.exception(e)
  962. raise HTTPException(
  963. status_code=400,
  964. detail=str(e),
  965. )
  966. payload = {**form_data.model_dump(exclude_none=True)}
  967. if "metadata" in payload:
  968. del payload["metadata"]
  969. model_id = payload["model"]
  970. model_info = Models.get_model_by_id(model_id)
  971. if model_info:
  972. if model_info.base_model_id:
  973. payload["model"] = model_info.base_model_id
  974. params = model_info.params.model_dump()
  975. if params:
  976. if payload.get("options") is None:
  977. payload["options"] = {}
  978. payload["options"] = apply_model_params_to_body_ollama(
  979. params, payload["options"]
  980. )
  981. payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
  982. # Check if user has access to the model
  983. if not bypass_filter and user.role == "user":
  984. if not (
  985. user.id == model_info.user_id
  986. or has_access(
  987. user.id, type="read", access_control=model_info.access_control
  988. )
  989. ):
  990. raise HTTPException(
  991. status_code=403,
  992. detail="Model not found",
  993. )
  994. elif not bypass_filter:
  995. if user.role != "admin":
  996. raise HTTPException(
  997. status_code=403,
  998. detail="Model not found",
  999. )
  1000. if ":" not in payload["model"]:
  1001. payload["model"] = f"{payload['model']}:latest"
  1002. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  1003. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  1004. str(url_idx),
  1005. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  1006. )
  1007. prefix_id = api_config.get("prefix_id", None)
  1008. if prefix_id:
  1009. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  1010. # payload["keep_alive"] = -1 # keep alive forever
  1011. return await send_post_request(
  1012. url=f"{url}/api/chat",
  1013. payload=json.dumps(payload),
  1014. stream=form_data.stream,
  1015. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1016. content_type="application/x-ndjson",
  1017. user=user,
  1018. )
  1019. # TODO: we should update this part once Ollama supports other types
  1020. class OpenAIChatMessageContent(BaseModel):
  1021. type: str
  1022. model_config = ConfigDict(extra="allow")
  1023. class OpenAIChatMessage(BaseModel):
  1024. role: str
  1025. content: Union[Optional[str], list[OpenAIChatMessageContent]]
  1026. model_config = ConfigDict(extra="allow")
  1027. class OpenAIChatCompletionForm(BaseModel):
  1028. model: str
  1029. messages: list[OpenAIChatMessage]
  1030. model_config = ConfigDict(extra="allow")
  1031. class OpenAICompletionForm(BaseModel):
  1032. model: str
  1033. prompt: str
  1034. model_config = ConfigDict(extra="allow")
  1035. @router.post("/v1/completions")
  1036. @router.post("/v1/completions/{url_idx}")
  1037. async def generate_openai_completion(
  1038. request: Request,
  1039. form_data: dict,
  1040. url_idx: Optional[int] = None,
  1041. user=Depends(get_verified_user),
  1042. ):
  1043. try:
  1044. form_data = OpenAICompletionForm(**form_data)
  1045. except Exception as e:
  1046. log.exception(e)
  1047. raise HTTPException(
  1048. status_code=400,
  1049. detail=str(e),
  1050. )
  1051. payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
  1052. if "metadata" in payload:
  1053. del payload["metadata"]
  1054. model_id = form_data.model
  1055. if ":" not in model_id:
  1056. model_id = f"{model_id}:latest"
  1057. model_info = Models.get_model_by_id(model_id)
  1058. if model_info:
  1059. if model_info.base_model_id:
  1060. payload["model"] = model_info.base_model_id
  1061. params = model_info.params.model_dump()
  1062. if params:
  1063. payload = apply_model_params_to_body_openai(params, payload)
  1064. # Check if user has access to the model
  1065. if user.role == "user":
  1066. if not (
  1067. user.id == model_info.user_id
  1068. or has_access(
  1069. user.id, type="read", access_control=model_info.access_control
  1070. )
  1071. ):
  1072. raise HTTPException(
  1073. status_code=403,
  1074. detail="Model not found",
  1075. )
  1076. else:
  1077. if user.role != "admin":
  1078. raise HTTPException(
  1079. status_code=403,
  1080. detail="Model not found",
  1081. )
  1082. if ":" not in payload["model"]:
  1083. payload["model"] = f"{payload['model']}:latest"
  1084. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  1085. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  1086. str(url_idx),
  1087. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  1088. )
  1089. prefix_id = api_config.get("prefix_id", None)
  1090. if prefix_id:
  1091. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  1092. return await send_post_request(
  1093. url=f"{url}/v1/completions",
  1094. payload=json.dumps(payload),
  1095. stream=payload.get("stream", False),
  1096. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1097. user=user,
  1098. )
  1099. @router.post("/v1/chat/completions")
  1100. @router.post("/v1/chat/completions/{url_idx}")
  1101. async def generate_openai_chat_completion(
  1102. request: Request,
  1103. form_data: dict,
  1104. url_idx: Optional[int] = None,
  1105. user=Depends(get_verified_user),
  1106. ):
  1107. metadata = form_data.pop("metadata", None)
  1108. try:
  1109. completion_form = OpenAIChatCompletionForm(**form_data)
  1110. except Exception as e:
  1111. log.exception(e)
  1112. raise HTTPException(
  1113. status_code=400,
  1114. detail=str(e),
  1115. )
  1116. payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
  1117. if "metadata" in payload:
  1118. del payload["metadata"]
  1119. model_id = completion_form.model
  1120. if ":" not in model_id:
  1121. model_id = f"{model_id}:latest"
  1122. model_info = Models.get_model_by_id(model_id)
  1123. if model_info:
  1124. if model_info.base_model_id:
  1125. payload["model"] = model_info.base_model_id
  1126. params = model_info.params.model_dump()
  1127. if params:
  1128. payload = apply_model_params_to_body_openai(params, payload)
  1129. payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
  1130. # Check if user has access to the model
  1131. if user.role == "user":
  1132. if not (
  1133. user.id == model_info.user_id
  1134. or has_access(
  1135. user.id, type="read", access_control=model_info.access_control
  1136. )
  1137. ):
  1138. raise HTTPException(
  1139. status_code=403,
  1140. detail="Model not found",
  1141. )
  1142. else:
  1143. if user.role != "admin":
  1144. raise HTTPException(
  1145. status_code=403,
  1146. detail="Model not found",
  1147. )
  1148. if ":" not in payload["model"]:
  1149. payload["model"] = f"{payload['model']}:latest"
  1150. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  1151. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  1152. str(url_idx),
  1153. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  1154. )
  1155. prefix_id = api_config.get("prefix_id", None)
  1156. if prefix_id:
  1157. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  1158. return await send_post_request(
  1159. url=f"{url}/v1/chat/completions",
  1160. payload=json.dumps(payload),
  1161. stream=payload.get("stream", False),
  1162. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1163. user=user,
  1164. )
  1165. @router.get("/v1/models")
  1166. @router.get("/v1/models/{url_idx}")
  1167. async def get_openai_models(
  1168. request: Request,
  1169. url_idx: Optional[int] = None,
  1170. user=Depends(get_verified_user),
  1171. ):
  1172. models = []
  1173. if url_idx is None:
  1174. model_list = await get_all_models(request, user=user)
  1175. models = [
  1176. {
  1177. "id": model["model"],
  1178. "object": "model",
  1179. "created": int(time.time()),
  1180. "owned_by": "openai",
  1181. }
  1182. for model in model_list["models"]
  1183. ]
  1184. else:
  1185. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1186. try:
  1187. r = requests.request(method="GET", url=f"{url}/api/tags")
  1188. r.raise_for_status()
  1189. model_list = r.json()
  1190. models = [
  1191. {
  1192. "id": model["model"],
  1193. "object": "model",
  1194. "created": int(time.time()),
  1195. "owned_by": "openai",
  1196. }
  1197. for model in models["models"]
  1198. ]
  1199. except Exception as e:
  1200. log.exception(e)
  1201. error_detail = "Open WebUI: Server Connection Error"
  1202. if r is not None:
  1203. try:
  1204. res = r.json()
  1205. if "error" in res:
  1206. error_detail = f"Ollama: {res['error']}"
  1207. except Exception:
  1208. error_detail = f"Ollama: {e}"
  1209. raise HTTPException(
  1210. status_code=r.status_code if r else 500,
  1211. detail=error_detail,
  1212. )
  1213. if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
  1214. # Filter models based on user access control
  1215. filtered_models = []
  1216. for model in models:
  1217. model_info = Models.get_model_by_id(model["id"])
  1218. if model_info:
  1219. if user.id == model_info.user_id or has_access(
  1220. user.id, type="read", access_control=model_info.access_control
  1221. ):
  1222. filtered_models.append(model)
  1223. models = filtered_models
  1224. return {
  1225. "data": models,
  1226. "object": "list",
  1227. }
  1228. class UrlForm(BaseModel):
  1229. url: str
  1230. class UploadBlobForm(BaseModel):
  1231. filename: str
  1232. def parse_huggingface_url(hf_url):
  1233. try:
  1234. # Parse the URL
  1235. parsed_url = urlparse(hf_url)
  1236. # Get the path and split it into components
  1237. path_components = parsed_url.path.split("/")
  1238. # Extract the desired output
  1239. model_file = path_components[-1]
  1240. return model_file
  1241. except ValueError:
  1242. return None
  1243. async def download_file_stream(
  1244. ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
  1245. ):
  1246. done = False
  1247. if os.path.exists(file_path):
  1248. current_size = os.path.getsize(file_path)
  1249. else:
  1250. current_size = 0
  1251. headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
  1252. timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
  1253. async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
  1254. async with session.get(
  1255. file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL
  1256. ) as response:
  1257. total_size = int(response.headers.get("content-length", 0)) + current_size
  1258. with open(file_path, "ab+") as file:
  1259. async for data in response.content.iter_chunked(chunk_size):
  1260. current_size += len(data)
  1261. file.write(data)
  1262. done = current_size == total_size
  1263. progress = round((current_size / total_size) * 100, 2)
  1264. yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
  1265. if done:
  1266. file.seek(0)
  1267. chunk_size = 1024 * 1024 * 2
  1268. hashed = calculate_sha256(file, chunk_size)
  1269. file.seek(0)
  1270. url = f"{ollama_url}/api/blobs/sha256:{hashed}"
  1271. response = requests.post(url, data=file)
  1272. if response.ok:
  1273. res = {
  1274. "done": done,
  1275. "blob": f"sha256:{hashed}",
  1276. "name": file_name,
  1277. }
  1278. os.remove(file_path)
  1279. yield f"data: {json.dumps(res)}\n\n"
  1280. else:
  1281. raise "Ollama: Could not create blob, Please try again."
  1282. # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
  1283. @router.post("/models/download")
  1284. @router.post("/models/download/{url_idx}")
  1285. async def download_model(
  1286. request: Request,
  1287. form_data: UrlForm,
  1288. url_idx: Optional[int] = None,
  1289. user=Depends(get_admin_user),
  1290. ):
  1291. allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
  1292. if not any(form_data.url.startswith(host) for host in allowed_hosts):
  1293. raise HTTPException(
  1294. status_code=400,
  1295. detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
  1296. )
  1297. if url_idx is None:
  1298. url_idx = 0
  1299. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1300. file_name = parse_huggingface_url(form_data.url)
  1301. if file_name:
  1302. file_path = f"{UPLOAD_DIR}/{file_name}"
  1303. return StreamingResponse(
  1304. download_file_stream(url, form_data.url, file_path, file_name),
  1305. )
  1306. else:
  1307. return None
  1308. # TODO: Progress bar does not reflect size & duration of upload.
  1309. @router.post("/models/upload")
  1310. @router.post("/models/upload/{url_idx}")
  1311. async def upload_model(
  1312. request: Request,
  1313. file: UploadFile = File(...),
  1314. url_idx: Optional[int] = None,
  1315. user=Depends(get_admin_user),
  1316. ):
  1317. if url_idx is None:
  1318. url_idx = 0
  1319. ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1320. filename = os.path.basename(file.filename)
  1321. file_path = os.path.join(UPLOAD_DIR, filename)
  1322. os.makedirs(UPLOAD_DIR, exist_ok=True)
  1323. # --- P1: save file locally ---
  1324. chunk_size = 1024 * 1024 * 2 # 2 MB chunks
  1325. with open(file_path, "wb") as out_f:
  1326. while True:
  1327. chunk = file.file.read(chunk_size)
  1328. # log.info(f"Chunk: {str(chunk)}") # DEBUG
  1329. if not chunk:
  1330. break
  1331. out_f.write(chunk)
  1332. async def file_process_stream():
  1333. nonlocal ollama_url
  1334. total_size = os.path.getsize(file_path)
  1335. log.info(f"Total Model Size: {str(total_size)}") # DEBUG
  1336. # --- P2: SSE progress + calculate sha256 hash ---
  1337. file_hash = calculate_sha256(file_path, chunk_size)
  1338. log.info(f"Model Hash: {str(file_hash)}") # DEBUG
  1339. try:
  1340. with open(file_path, "rb") as f:
  1341. bytes_read = 0
  1342. while chunk := f.read(chunk_size):
  1343. bytes_read += len(chunk)
  1344. progress = round(bytes_read / total_size * 100, 2)
  1345. data_msg = {
  1346. "progress": progress,
  1347. "total": total_size,
  1348. "completed": bytes_read,
  1349. }
  1350. yield f"data: {json.dumps(data_msg)}\n\n"
  1351. # --- P3: Upload to ollama /api/blobs ---
  1352. with open(file_path, "rb") as f:
  1353. url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
  1354. response = requests.post(url, data=f)
  1355. if response.ok:
  1356. log.info(f"Uploaded to /api/blobs") # DEBUG
  1357. # Remove local file
  1358. os.remove(file_path)
  1359. # Create model in ollama
  1360. model_name, ext = os.path.splitext(filename)
  1361. log.info(f"Created Model: {model_name}") # DEBUG
  1362. create_payload = {
  1363. "model": model_name,
  1364. # Reference the file by its original name => the uploaded blob's digest
  1365. "files": {filename: f"sha256:{file_hash}"},
  1366. }
  1367. log.info(f"Model Payload: {create_payload}") # DEBUG
  1368. # Call ollama /api/create
  1369. # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
  1370. create_resp = requests.post(
  1371. url=f"{ollama_url}/api/create",
  1372. headers={"Content-Type": "application/json"},
  1373. data=json.dumps(create_payload),
  1374. )
  1375. if create_resp.ok:
  1376. log.info(f"API SUCCESS!") # DEBUG
  1377. done_msg = {
  1378. "done": True,
  1379. "blob": f"sha256:{file_hash}",
  1380. "name": filename,
  1381. "model_created": model_name,
  1382. }
  1383. yield f"data: {json.dumps(done_msg)}\n\n"
  1384. else:
  1385. raise Exception(
  1386. f"Failed to create model in Ollama. {create_resp.text}"
  1387. )
  1388. else:
  1389. raise Exception("Ollama: Could not create blob, Please try again.")
  1390. except Exception as e:
  1391. res = {"error": str(e)}
  1392. yield f"data: {json.dumps(res)}\n\n"
  1393. return StreamingResponse(file_process_stream(), media_type="text/event-stream")