Sfoglia il codice sorgente

Merge pull request #14402 from torisetxd/parallelized-model-fetching

perf: Parallelize base model fetching
Tim Jaeryang Baek 4 mesi fa
parent
commit
100a764293
1 file modificato con 38 aggiunte e 26 eliminazioni
  1. 38 26
      backend/open_webui/utils/models.py

+ 38 - 26
backend/open_webui/utils/models.py

@@ -1,5 +1,6 @@
 import time
 import logging
+import asyncio
 import sys
 
 from aiocache import cached
@@ -33,35 +34,46 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
-async def get_all_base_models(request: Request, user: UserModel = None):
-    function_models = []
-    openai_models = []
-    ollama_models = []
-
-    if request.app.state.config.ENABLE_OPENAI_API:
-        openai_models = await openai.get_all_models(request, user=user)
-        openai_models = openai_models["data"]
async def fetch_ollama_models(request: Request, user: UserModel = None):
    """Fetch all Ollama models and normalize them into OpenAI-style model dicts.

    Args:
        request: The incoming request; its app state/config is used by the
            Ollama client to resolve configured connections.
        user: Optional user, forwarded for per-user access filtering.

    Returns:
        list[dict]: One entry per Ollama model, with the raw model payload
        preserved under the ``"ollama"`` key.
    """
    raw_ollama_models = await ollama.get_all_models(request, user=user)
    # Hoisted out of the comprehension: one consistent timestamp for the
    # whole batch instead of a per-model clock read.
    created = int(time.time())
    return [
        {
            "id": model["model"],
            "name": model["name"],
            "object": "model",
            "created": created,
            "owned_by": "ollama",
            "ollama": model,
            "connection_type": model.get("connection_type", "local"),
            "tags": model.get("tags", []),
        }
        for model in raw_ollama_models["models"]
    ]
 
-    if request.app.state.config.ENABLE_OLLAMA_API:
-        ollama_models = await ollama.get_all_models(request, user=user)
-        ollama_models = [
-            {
-                "id": model["model"],
-                "name": model["name"],
-                "object": "model",
-                "created": int(time.time()),
-                "owned_by": "ollama",
-                "ollama": model,
-                "connection_type": model.get("connection_type", "local"),
-                "tags": model.get("tags", []),
-            }
-            for model in ollama_models["models"]
-        ]
async def fetch_openai_models(request: Request, user: UserModel = None):
    """Return the list of models exposed by the configured OpenAI-compatible APIs."""
    response = await openai.get_all_models(request, user=user)
    return response["data"]
 
-    function_models = await get_function_models(request)
-    models = function_models + openai_models + ollama_models
 
-    return models
async def get_all_base_models(request: Request, user: UserModel = None):
    """Collect base models from the OpenAI, Ollama, and function sources concurrently.

    A disabled backend contributes an empty list via a zero-delay sleep, so
    ``asyncio.gather`` always receives three awaitables in a fixed order.
    """
    config = request.app.state.config

    openai_coro = (
        fetch_openai_models(request, user)
        if config.ENABLE_OPENAI_API
        else asyncio.sleep(0, result=[])
    )
    ollama_coro = (
        fetch_ollama_models(request, user)
        if config.ENABLE_OLLAMA_API
        else asyncio.sleep(0, result=[])
    )

    results = await asyncio.gather(
        openai_coro, ollama_coro, get_function_models(request)
    )
    openai_models, ollama_models, function_models = results

    return function_models + openai_models + ollama_models
 
 
 async def get_all_models(request, user: UserModel = None):