@@ -300,6 +300,22 @@ async def update_config(
     }
 
 
+def merge_ollama_models_lists(model_lists):
+    merged_models = {}
+
+    for idx, model_list in enumerate(model_lists):
+        if model_list is not None:
+            for model in model_list:
+                id = model["model"]
+                if id not in merged_models:
+                    model["urls"] = [idx]
+                    merged_models[id] = model
+                else:
+                    merged_models[id]["urls"].append(idx)
+
+    return list(merged_models.values())
+
+
 @cached(ttl=1)
 async def get_all_models(request: Request, user: UserModel = None):
     log.info("get_all_models()")
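To illustrate the hunk above: merge_ollama_models_lists deduplicates the per-node model
lists by their "model" key and records, in a "urls" field, the indices of the base URLs
that reported each model. A minimal sketch, assuming the merge_ollama_models_lists
definition from the hunk above is in scope (the sample payloads and model names are
made up; only the "model" key matters to the merge logic):

    node_0 = [{"model": "llama3:8b"}, {"model": "mistral:7b"}]
    node_1 = [{"model": "llama3:8b"}]
    node_2 = None  # e.g. an unreachable or disabled node

    merged = merge_ollama_models_lists([node_0, node_1, node_2])
    # [{'model': 'llama3:8b', 'urls': [0, 1]}, {'model': 'mistral:7b', 'urls': [0]}]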
@@ -364,23 +380,8 @@ async def get_all_models(request: Request, user: UserModel = None):
                     if connection_type:
                         model["connection_type"] = connection_type
 
-        def merge_models_lists(model_lists):
-            merged_models = {}
-
-            for idx, model_list in enumerate(model_lists):
-                if model_list is not None:
-                    for model in model_list:
-                        id = model["model"]
-                        if id not in merged_models:
-                            model["urls"] = [idx]
-                            merged_models[id] = model
-                        else:
-                            merged_models[id]["urls"].append(idx)
-
-            return list(merged_models.values())
-
         models = {
-            "models": merge_models_lists(
+            "models": merge_ollama_models_lists(
                 map(
                     lambda response: response.get("models", []) if response else None,
                     responses,
@@ -468,6 +469,72 @@ async def get_ollama_tags(
     return models
 
 
+@router.get("/api/ps")
+async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
+    """
+    List models that are currently loaded into Ollama memory, and which node they are loaded on.
+    """
+    if request.app.state.config.ENABLE_OLLAMA_API:
+        request_tasks = []
+        for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
+            if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
+                url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
+            ):
+                request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
+            else:
+                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+                    str(idx),
+                    request.app.state.config.OLLAMA_API_CONFIGS.get(
+                        url, {}
+                    ),  # Legacy support
+                )
+
+                enable = api_config.get("enable", True)
+                key = api_config.get("key", None)
+
+                if enable:
+                    request_tasks.append(
+                        send_get_request(f"{url}/api/ps", key, user=user)
+                    )
+                else:
+                    request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
+
+        responses = await asyncio.gather(*request_tasks)
+
+        for idx, response in enumerate(responses):
+            if response:
+                url = request.app.state.config.OLLAMA_BASE_URLS[idx]
+                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+                    str(idx),
+                    request.app.state.config.OLLAMA_API_CONFIGS.get(
+                        url, {}
+                    ),  # Legacy support
+                )
+
+                prefix_id = api_config.get("prefix_id", None)
+
+                for model in response.get("models", []):
+                    if prefix_id:
+                        model["model"] = f"{prefix_id}.{model['model']}"
+
+        models = {
+            "models": merge_ollama_models_lists(
+                map(
+                    lambda response: response.get("models", []) if response else None,
+                    responses,
+                )
+            )
+        }
+
+        if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
+            models["models"] = await get_filtered_models(models, user)
+
+    else:
+        models = {"models": []}
+
+    return models
+
+
 @router.get("/api/version")
 @router.get("/api/version/{url_idx}")
 async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
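The consolidated GET /api/ps handler above returns a single {"models": [...]} payload in
which each merged entry carries a "urls" list of indices into OLLAMA_BASE_URLS, so callers
can resolve which nodes currently have a model loaded. A hedged sketch of consuming such a
response (the base URLs, model names, and the "size" field are illustrative only; the exact
per-model fields depend on the Ollama version):

    OLLAMA_BASE_URLS = ["http://ollama-a:11434", "http://ollama-b:11434"]

    response = {
        "models": [
            {"model": "llama3:8b", "size": 5137025024, "urls": [0, 1]},
            {"model": "mistral:7b", "size": 4109865159, "urls": [1]},
        ]
    }

    for entry in response["models"]:
        nodes = [OLLAMA_BASE_URLS[i] for i in entry["urls"]]
        print(f"{entry['model']} loaded on: {', '.join(nodes)}")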
@@ -541,32 +608,6 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
         return {"version": False}
 
 
-@router.get("/api/ps")
-async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
-    """
-    List models that are currently loaded into Ollama memory, and which node they are loaded on.
-    """
-    if request.app.state.config.ENABLE_OLLAMA_API:
-        request_tasks = [
-            send_get_request(
-                f"{url}/api/ps",
-                request.app.state.config.OLLAMA_API_CONFIGS.get(
-                    str(idx),
-                    request.app.state.config.OLLAMA_API_CONFIGS.get(
-                        url, {}
-                    ),  # Legacy support
-                ).get("key", None),
-                user=user,
-            )
-            for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
-        ]
-        responses = await asyncio.gather(*request_tasks)
-
-        return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
-    else:
-        return {}
-
-
 class ModelNameForm(BaseModel):
     name: str
 
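One detail of the rewritten handler worth spelling out: for connections disabled in
OLLAMA_API_CONFIGS it appends asyncio.ensure_future(asyncio.sleep(0, None)) rather than
skipping the URL, because asyncio.gather preserves order and responses[idx] must stay
aligned with OLLAMA_BASE_URLS[idx]. A minimal standalone sketch of that pattern
(fake_ps and the node URLs are hypothetical stand-ins for send_get_request and the
configured base URLs):

    import asyncio


    async def fake_ps(url):
        # stand-in for send_get_request(f"{url}/api/ps", ...)
        return {"models": [{"model": "llama3:8b"}]}


    async def main():
        urls = ["http://node-a:11434", "http://node-b:11434", "http://node-c:11434"]
        disabled = {1}  # pretend the second connection has "enable": False

        tasks = [
            asyncio.ensure_future(asyncio.sleep(0, None))  # resolves to None, keeps the slot
            if idx in disabled
            else fake_ps(url)
            for idx, url in enumerate(urls)
        ]
        responses = await asyncio.gather(*tasks)
        for idx, response in enumerate(responses):
            print(urls[idx], "->", response)  # the disabled node prints None


    asyncio.run(main())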