|
@@ -17,10 +17,14 @@ from open_webui.config import (
|
|
ENABLE_OLLAMA_API,
|
|
ENABLE_OLLAMA_API,
|
|
MODEL_FILTER_LIST,
|
|
MODEL_FILTER_LIST,
|
|
OLLAMA_BASE_URLS,
|
|
OLLAMA_BASE_URLS,
|
|
|
|
+ OLLAMA_API_CONFIGS,
|
|
UPLOAD_DIR,
|
|
UPLOAD_DIR,
|
|
AppConfig,
|
|
AppConfig,
|
|
)
|
|
)
|
|
-from open_webui.env import AIOHTTP_CLIENT_TIMEOUT
|
|
|
|
|
|
+from open_webui.env import (
|
|
|
|
+ AIOHTTP_CLIENT_TIMEOUT,
|
|
|
|
+ AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
|
|
|
+)
|
|
|
|
|
|
|
|
|
|
from open_webui.constants import ERROR_MESSAGES
|
|
from open_webui.constants import ERROR_MESSAGES
|
|
@@ -67,6 +71,8 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
|
|
|
|
|
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
|
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
|
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
|
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
|
|
|
+app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
|
|
|
|
+
|
|
app.state.MODELS = {}
|
|
app.state.MODELS = {}
|
|
|
|
|
|
|
|
|
|
@@ -92,17 +98,64 @@ async def get_status():
|
|
return {"status": True}
|
|
return {"status": True}
|
|
|
|
|
|
|
|
|
|
|
|
+class ConnectionVerificationForm(BaseModel):
|
|
|
|
+ url: str
|
|
|
|
+ key: Optional[str] = None
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.post("/verify")
|
|
|
|
+async def verify_connection(
|
|
|
|
+ form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
|
|
|
|
+):
|
|
|
|
+ url = form_data.url
|
|
|
|
+ key = form_data.key
|
|
|
|
+
|
|
|
|
+ headers = {}
|
|
|
|
+ if key:
|
|
|
|
+ headers["Authorization"] = f"Bearer {key}"
|
|
|
|
+
|
|
|
|
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
|
|
|
+ async with aiohttp.ClientSession(timeout=timeout) as session:
|
|
|
|
+ try:
|
|
|
|
+ async with session.get(f"{url}/api/version", headers=headers) as r:
|
|
|
|
+ if r.status != 200:
|
|
|
|
+ # Extract response error details if available
|
|
|
|
+ error_detail = f"HTTP Error: {r.status}"
|
|
|
|
+ res = await r.json()
|
|
|
|
+ if "error" in res:
|
|
|
|
+ error_detail = f"External Error: {res['error']}"
|
|
|
|
+ raise Exception(error_detail)
|
|
|
|
+
|
|
|
|
+ response_data = await r.json()
|
|
|
|
+ return response_data
|
|
|
|
+
|
|
|
|
+ except aiohttp.ClientError as e:
|
|
|
|
+ # ClientError covers all aiohttp requests issues
|
|
|
|
+ log.exception(f"Client error: {str(e)}")
|
|
|
|
+ # Handle aiohttp-specific connection issues, timeout etc.
|
|
|
|
+ raise HTTPException(
|
|
|
|
+ status_code=500, detail="Open WebUI: Server Connection Error"
|
|
|
|
+ )
|
|
|
|
+ except Exception as e:
|
|
|
|
+ log.exception(f"Unexpected error: {e}")
|
|
|
|
+ # Generic error handler in case parsing JSON or other steps fail
|
|
|
|
+ error_detail = f"Unexpected error: {str(e)}"
|
|
|
|
+ raise HTTPException(status_code=500, detail=error_detail)
|
|
|
|
+
|
|
|
|
+
|
|
@app.get("/config")
|
|
@app.get("/config")
|
|
async def get_config(user=Depends(get_admin_user)):
|
|
async def get_config(user=Depends(get_admin_user)):
|
|
return {
|
|
return {
|
|
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
|
|
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
|
|
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
|
|
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
|
|
|
|
+ "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
class OllamaConfigForm(BaseModel):
|
|
class OllamaConfigForm(BaseModel):
|
|
ENABLE_OLLAMA_API: Optional[bool] = None
|
|
ENABLE_OLLAMA_API: Optional[bool] = None
|
|
OLLAMA_BASE_URLS: list[str]
|
|
OLLAMA_BASE_URLS: list[str]
|
|
|
|
+ OLLAMA_API_CONFIGS: dict
|
|
|
|
|
|
|
|
|
|
@app.post("/config/update")
|
|
@app.post("/config/update")
|
|
@@ -110,17 +163,27 @@ async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user
|
|
app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
|
|
app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
|
|
app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
|
|
app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
|
|
|
|
|
|
|
|
+ app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
|
|
|
|
+
|
|
|
|
+    # Remove any stale configs: drop entries for URLs no longer in OLLAMA_BASE_URLS
+    config_urls = app.state.config.OLLAMA_API_CONFIGS.keys()
+    for url in list(config_urls):
+        if url not in app.state.config.OLLAMA_BASE_URLS:
+            app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
|
|
|
|
+
|
|
return {
|
|
return {
|
|
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
|
|
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
|
|
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
|
|
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
|
|
|
|
+ "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
-async def fetch_url(url):
|
|
|
|
- timeout = aiohttp.ClientTimeout(total=3)
|
|
|
|
|
|
+async def aiohttp_get(url, key=None):
|
|
|
|
+ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
|
|
try:
|
|
try:
|
|
|
|
+ headers = {"Authorization": f"Bearer {key}"} if key else {}
|
|
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
|
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
|
- async with session.get(url) as response:
|
|
|
|
|
|
+ async with session.get(url, headers=headers) as response:
|
|
return await response.json()
|
|
return await response.json()
|
|
except Exception as e:
|
|
except Exception as e:
|
|
# Handle connection error here
|
|
# Handle connection error here
|
|
@@ -204,13 +267,42 @@ def merge_models_lists(model_lists):
|
|
|
|
|
|
async def get_all_models():
|
|
async def get_all_models():
|
|
log.info("get_all_models()")
|
|
log.info("get_all_models()")
|
|
-
|
|
|
|
if app.state.config.ENABLE_OLLAMA_API:
|
|
if app.state.config.ENABLE_OLLAMA_API:
|
|
- tasks = [
|
|
|
|
- fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS
|
|
|
|
- ]
|
|
|
|
|
|
+ tasks = []
|
|
|
|
+ for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS):
|
|
|
|
+ if url not in app.state.config.OLLAMA_API_CONFIGS:
|
|
|
|
+ tasks.append(aiohttp_get(f"{url}/api/tags"))
|
|
|
|
+ else:
|
|
|
|
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
|
|
+ enable = api_config.get("enable", True)
|
|
|
|
+
|
|
|
|
+ if enable:
|
|
|
|
+ tasks.append(aiohttp_get(f"{url}/api/tags"))
|
|
|
|
+                else:
+                    # Placeholder awaitable keeps responses index-aligned with
+                    # OLLAMA_BASE_URLS; asyncio.gather() rejects bare None.
+                    tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
|
|
|
|
+
|
|
responses = await asyncio.gather(*tasks)
|
|
responses = await asyncio.gather(*tasks)
|
|
|
|
|
|
|
|
+ for idx, response in enumerate(responses):
|
|
|
|
+ if response:
|
|
|
|
+ url = app.state.config.OLLAMA_BASE_URLS[idx]
|
|
|
|
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
|
|
+
|
|
|
|
+ prefix_id = api_config.get("prefix_id", None)
|
|
|
|
+ model_ids = api_config.get("model_ids", [])
|
|
|
|
+
|
|
|
|
+ if len(model_ids) != 0:
|
|
|
|
+ response["models"] = list(
|
|
|
|
+ filter(
|
|
|
|
+ lambda model: model["model"] in model_ids,
|
|
|
|
+ response["models"],
|
|
|
|
+ )
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ if prefix_id:
|
|
|
|
+ for model in response["models"]:
|
|
|
|
+ model["model"] = f"{prefix_id}.{model['model']}"
|
|
|
|
+
|
|
models = {
|
|
models = {
|
|
"models": merge_models_lists(
|
|
"models": merge_models_lists(
|
|
map(
|
|
map(
|
|
@@ -279,7 +371,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
|
|
if url_idx is None:
|
|
if url_idx is None:
|
|
# returns lowest version
|
|
# returns lowest version
|
|
tasks = [
|
|
tasks = [
|
|
- fetch_url(f"{url}/api/version")
|
|
|
|
|
|
+ aiohttp_get(f"{url}/api/version")
|
|
for url in app.state.config.OLLAMA_BASE_URLS
|
|
for url in app.state.config.OLLAMA_BASE_URLS
|
|
]
|
|
]
|
|
responses = await asyncio.gather(*tasks)
|
|
responses = await asyncio.gather(*tasks)
|
|
@@ -718,6 +810,10 @@ async def generate_completion(
|
|
)
|
|
)
|
|
|
|
|
|
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
|
|
+ prefix_id = api_config.get("prefix_id", None)
|
|
|
|
+ if prefix_id:
|
|
|
|
+ form_data.model = form_data.model.replace(f"{prefix_id}.", "")
|
|
log.info(f"url: {url}")
|
|
log.info(f"url: {url}")
|
|
|
|
|
|
return await post_streaming_url(
|
|
return await post_streaming_url(
|
|
@@ -799,6 +895,11 @@ async def generate_chat_completion(
|
|
log.info(f"url: {url}")
|
|
log.info(f"url: {url}")
|
|
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
|
|
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
|
|
|
|
|
|
|
|
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
|
|
+ prefix_id = api_config.get("prefix_id", None)
|
|
|
|
+ if prefix_id:
|
|
|
|
+ payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
|
|
|
+
|
|
return await post_streaming_url(
|
|
return await post_streaming_url(
|
|
f"{url}/api/chat",
|
|
f"{url}/api/chat",
|
|
json.dumps(payload),
|
|
json.dumps(payload),
|
|
@@ -874,6 +975,11 @@ async def generate_openai_chat_completion(
|
|
url = get_ollama_url(url_idx, payload["model"])
|
|
url = get_ollama_url(url_idx, payload["model"])
|
|
log.info(f"url: {url}")
|
|
log.info(f"url: {url}")
|
|
|
|
|
|
|
|
+ api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
|
|
+ prefix_id = api_config.get("prefix_id", None)
|
|
|
|
+ if prefix_id:
|
|
|
|
+ payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
|
|
|
+
|
|
return await post_streaming_url(
|
|
return await post_streaming_url(
|
|
f"{url}/v1/chat/completions",
|
|
f"{url}/v1/chat/completions",
|
|
json.dumps(payload),
|
|
json.dumps(payload),
|