فهرست منبع

refac: ollama ps

Timothy Jaeryang Baek 4 ماه پیش
والد
کامیت
a50a8e2ef9
1فایلهای تغییر یافته به همراه83 افزوده شده و 42 حذف شده
  1. 83 42
      backend/open_webui/routers/ollama.py

+ 83 - 42
backend/open_webui/routers/ollama.py

@@ -300,6 +300,22 @@ async def update_config(
     }
 
 
+def merge_ollama_models_lists(model_lists):
+    merged_models = {}
+
+    for idx, model_list in enumerate(model_lists):
+        if model_list is not None:
+            for model in model_list:
+                id = model["model"]
+                if id not in merged_models:
+                    model["urls"] = [idx]
+                    merged_models[id] = model
+                else:
+                    merged_models[id]["urls"].append(idx)
+
+    return list(merged_models.values())
+
+
 @cached(ttl=1)
 async def get_all_models(request: Request, user: UserModel = None):
     log.info("get_all_models()")
@@ -364,23 +380,8 @@ async def get_all_models(request: Request, user: UserModel = None):
                     if connection_type:
                         model["connection_type"] = connection_type
 
-        def merge_models_lists(model_lists):
-            merged_models = {}
-
-            for idx, model_list in enumerate(model_lists):
-                if model_list is not None:
-                    for model in model_list:
-                        id = model["model"]
-                        if id not in merged_models:
-                            model["urls"] = [idx]
-                            merged_models[id] = model
-                        else:
-                            merged_models[id]["urls"].append(idx)
-
-            return list(merged_models.values())
-
         models = {
-            "models": merge_models_lists(
+            "models": merge_ollama_models_lists(
                 map(
                     lambda response: response.get("models", []) if response else None,
                     responses,
@@ -468,6 +469,72 @@ async def get_ollama_tags(
     return models
 
 
+@router.get("/api/ps")
+async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
+    """
+    List models that are currently loaded into Ollama memory, and which node they are loaded on.
+    """
+    if request.app.state.config.ENABLE_OLLAMA_API:
+        request_tasks = []
+        for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
+            if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
+                url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
+            ):
+                request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
+            else:
+                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+                    str(idx),
+                    request.app.state.config.OLLAMA_API_CONFIGS.get(
+                        url, {}
+                    ),  # Legacy support
+                )
+
+                enable = api_config.get("enable", True)
+                key = api_config.get("key", None)
+
+                if enable:
+                    request_tasks.append(
+                        send_get_request(f"{url}/api/ps", key, user=user)
+                    )
+                else:
+                    request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
+
+        responses = await asyncio.gather(*request_tasks)
+
+        for idx, response in enumerate(responses):
+            if response:
+                url = request.app.state.config.OLLAMA_BASE_URLS[idx]
+                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
+                    str(idx),
+                    request.app.state.config.OLLAMA_API_CONFIGS.get(
+                        url, {}
+                    ),  # Legacy support
+                )
+
+                prefix_id = api_config.get("prefix_id", None)
+
+                for model in response.get("models", []):
+                    if prefix_id:
+                        model["model"] = f"{prefix_id}.{model['model']}"
+
+        models = {
+            "models": merge_ollama_models_lists(
+                map(
+                    lambda response: response.get("models", []) if response else None,
+                    responses,
+                )
+            )
+        }
+
+        if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
+            models["models"] = await get_filtered_models(models, user)
+
+    else:
+        models = {"models": []}
+
+    return models
+
+
 @router.get("/api/version")
 @router.get("/api/version/{url_idx}")
 async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
@@ -541,32 +608,6 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
         return {"version": False}
 
 
-@router.get("/api/ps")
-async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
-    """
-    List models that are currently loaded into Ollama memory, and which node they are loaded on.
-    """
-    if request.app.state.config.ENABLE_OLLAMA_API:
-        request_tasks = [
-            send_get_request(
-                f"{url}/api/ps",
-                request.app.state.config.OLLAMA_API_CONFIGS.get(
-                    str(idx),
-                    request.app.state.config.OLLAMA_API_CONFIGS.get(
-                        url, {}
-                    ),  # Legacy support
-                ).get("key", None),
-                user=user,
-            )
-            for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
-        ]
-        responses = await asyncio.gather(*request_tasks)
-
-        return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
-    else:
-        return {}
-
-
 class ModelNameForm(BaseModel):
     name: str