Selaa lähdekoodia

refac: ENABLE_MODEL_LIST_CACHE -> ENABLE_BASE_MODELS_CACHE

Timothy Jaeryang Baek 3 kuukautta sitten
vanhempi
commit
8a334decf6

+ 5 - 5
backend/open_webui/config.py

@@ -931,13 +931,13 @@ OPENAI_API_BASE_URL = "https://api.openai.com/v1"
 
 
 ####################################
-# MODEL_LIST
+# MODELS
 ####################################
 
-ENABLE_MODEL_LIST_CACHE = PersistentConfig(
-    "ENABLE_MODEL_LIST_CACHE",
-    "models.cache",
-    os.environ.get("ENABLE_MODEL_LIST_CACHE", "False").lower() == "true",
+ENABLE_BASE_MODELS_CACHE = PersistentConfig(
+    "ENABLE_BASE_MODELS_CACHE",
+    "models.base_models_cache",
+    os.environ.get("ENABLE_BASE_MODELS_CACHE", "False").lower() == "true",
 )
 
 

+ 5 - 4
backend/open_webui/main.py

@@ -117,7 +117,7 @@ from open_webui.config import (
     # Direct Connections
     ENABLE_DIRECT_CONNECTIONS,
     # Model list
-    ENABLE_MODEL_LIST_CACHE,
+    ENABLE_BASE_MODELS_CACHE,
     # Thread pool size for FastAPI/AnyIO
     THREAD_POOL_SIZE,
     # Tool Server Configs
@@ -537,7 +537,7 @@ async def lifespan(app: FastAPI):
 
     asyncio.create_task(periodic_usage_pool_cleanup())
 
-    if app.state.config.ENABLE_MODEL_LIST_CACHE:
+    if app.state.config.ENABLE_BASE_MODELS_CACHE:
         await get_all_models(
             Request(
                 # Creating a mock request object to pass to get_all_models
@@ -643,11 +643,12 @@ app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
 
 ########################################
 #
-# MODEL LIST
+# MODELS
 #
 ########################################
 
-app.state.config.ENABLE_MODEL_LIST_CACHE = ENABLE_MODEL_LIST_CACHE
+app.state.config.ENABLE_BASE_MODELS_CACHE = ENABLE_BASE_MODELS_CACHE
+app.state.BASE_MODELS = []
 
 ########################################
 #

+ 6 - 4
backend/open_webui/routers/configs.py

@@ -45,14 +45,14 @@ async def export_config(user=Depends(get_admin_user)):
 
 class ConnectionsConfigForm(BaseModel):
     ENABLE_DIRECT_CONNECTIONS: bool
-    ENABLE_MODEL_LIST_CACHE: bool
+    ENABLE_BASE_MODELS_CACHE: bool
 
 
 @router.get("/connections", response_model=ConnectionsConfigForm)
 async def get_connections_config(request: Request, user=Depends(get_admin_user)):
     return {
         "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
-        "ENABLE_MODEL_LIST_CACHE": request.app.state.config.ENABLE_MODEL_LIST_CACHE,
+        "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
     }
 
 
@@ -65,11 +65,13 @@ async def set_connections_config(
     request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
         form_data.ENABLE_DIRECT_CONNECTIONS
     )
-    request.app.state.config.ENABLE_MODEL_LIST_CACHE = form_data.ENABLE_MODEL_LIST_CACHE
+    request.app.state.config.ENABLE_BASE_MODELS_CACHE = (
+        form_data.ENABLE_BASE_MODELS_CACHE
+    )
 
     return {
         "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
-        "ENABLE_MODEL_LIST_CACHE": request.app.state.config.ENABLE_MODEL_LIST_CACHE,
+        "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
     }
 
 

+ 200 - 201
backend/open_webui/utils/models.py

@@ -77,176 +77,166 @@ async def get_all_base_models(request: Request, user: UserModel = None):
 
 
 async def get_all_models(request, refresh: bool = False, user: UserModel = None):
-    if request.app.state.MODELS and (
-        request.app.state.config.ENABLE_MODEL_LIST_CACHE and not refresh
+    if (
+        request.app.state.MODELS
+        and request.app.state.BASE_MODELS
+        and (request.app.state.config.ENABLE_BASE_MODELS_CACHE and not refresh)
     ):
-        return list(request.app.state.MODELS.values())
+        models = request.app.state.BASE_MODELS
     else:
         models = await get_all_base_models(request, user=user)
+        request.app.state.BASE_MODELS = models
 
-        # If there are no models, return an empty list
-        if len(models) == 0:
-            return []
-
-        # Add arena models
-        if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
-            arena_models = []
-            if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0:
-                arena_models = [
-                    {
-                        "id": model["id"],
-                        "name": model["name"],
-                        "info": {
-                            "meta": model["meta"],
-                        },
-                        "object": "model",
-                        "created": int(time.time()),
-                        "owned_by": "arena",
-                        "arena": True,
-                    }
-                    for model in request.app.state.config.EVALUATION_ARENA_MODELS
-                ]
-            else:
-                # Add default arena model
-                arena_models = [
-                    {
-                        "id": DEFAULT_ARENA_MODEL["id"],
-                        "name": DEFAULT_ARENA_MODEL["name"],
-                        "info": {
-                            "meta": DEFAULT_ARENA_MODEL["meta"],
-                        },
-                        "object": "model",
-                        "created": int(time.time()),
-                        "owned_by": "arena",
-                        "arena": True,
-                    }
-                ]
-            models = models + arena_models
-
-        global_action_ids = [
-            function.id for function in Functions.get_global_action_functions()
-        ]
-        enabled_action_ids = [
-            function.id
-            for function in Functions.get_functions_by_type("action", active_only=True)
-        ]
+    # If there are no models, return an empty list
+    if len(models) == 0:
+        return []
 
-        global_filter_ids = [
-            function.id for function in Functions.get_global_filter_functions()
-        ]
-        enabled_filter_ids = [
-            function.id
-            for function in Functions.get_functions_by_type("filter", active_only=True)
-        ]
+    # Add arena models
+    if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
+        arena_models = []
+        if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0:
+            arena_models = [
+                {
+                    "id": model["id"],
+                    "name": model["name"],
+                    "info": {
+                        "meta": model["meta"],
+                    },
+                    "object": "model",
+                    "created": int(time.time()),
+                    "owned_by": "arena",
+                    "arena": True,
+                }
+                for model in request.app.state.config.EVALUATION_ARENA_MODELS
+            ]
+        else:
+            # Add default arena model
+            arena_models = [
+                {
+                    "id": DEFAULT_ARENA_MODEL["id"],
+                    "name": DEFAULT_ARENA_MODEL["name"],
+                    "info": {
+                        "meta": DEFAULT_ARENA_MODEL["meta"],
+                    },
+                    "object": "model",
+                    "created": int(time.time()),
+                    "owned_by": "arena",
+                    "arena": True,
+                }
+            ]
+        models = models + arena_models
 
-        custom_models = Models.get_all_models()
-        for custom_model in custom_models:
-            if custom_model.base_model_id is None:
-                for model in models:
-                    if custom_model.id == model["id"] or (
-                        model.get("owned_by") == "ollama"
-                        and custom_model.id
-                        == model["id"].split(":")[
-                            0
-                        ]  # Ollama may return model ids in different formats (e.g., 'llama3' vs. 'llama3:7b')
-                    ):
-                        if custom_model.is_active:
-                            model["name"] = custom_model.name
-                            model["info"] = custom_model.model_dump()
-
-                            # Set action_ids and filter_ids
-                            action_ids = []
-                            filter_ids = []
-
-                            if "info" in model and "meta" in model["info"]:
-                                action_ids.extend(
-                                    model["info"]["meta"].get("actionIds", [])
-                                )
-                                filter_ids.extend(
-                                    model["info"]["meta"].get("filterIds", [])
-                                )
-
-                            model["action_ids"] = action_ids
-                            model["filter_ids"] = filter_ids
-                        else:
-                            models.remove(model)
-
-            elif custom_model.is_active and (
-                custom_model.id not in [model["id"] for model in models]
-            ):
-                owned_by = "openai"
-                pipe = None
-
-                action_ids = []
-                filter_ids = []
-
-                for model in models:
-                    if (
-                        custom_model.base_model_id == model["id"]
-                        or custom_model.base_model_id == model["id"].split(":")[0]
-                    ):
-                        owned_by = model.get("owned_by", "unknown owner")
-                        if "pipe" in model:
-                            pipe = model["pipe"]
-                        break
-
-                if custom_model.meta:
-                    meta = custom_model.meta.model_dump()
-
-                    if "actionIds" in meta:
-                        action_ids.extend(meta["actionIds"])
-
-                    if "filterIds" in meta:
-                        filter_ids.extend(meta["filterIds"])
-
-                models.append(
-                    {
-                        "id": f"{custom_model.id}",
-                        "name": custom_model.name,
-                        "object": "model",
-                        "created": custom_model.created_at,
-                        "owned_by": owned_by,
-                        "info": custom_model.model_dump(),
-                        "preset": True,
-                        **({"pipe": pipe} if pipe is not None else {}),
-                        "action_ids": action_ids,
-                        "filter_ids": filter_ids,
-                    }
-                )
+    global_action_ids = [
+        function.id for function in Functions.get_global_action_functions()
+    ]
+    enabled_action_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("action", active_only=True)
+    ]
+
+    global_filter_ids = [
+        function.id for function in Functions.get_global_filter_functions()
+    ]
+    enabled_filter_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("filter", active_only=True)
+    ]
+
+    custom_models = Models.get_all_models()
+    for custom_model in custom_models:
+        if custom_model.base_model_id is None:
+            for model in models:
+                if custom_model.id == model["id"] or (
+                    model.get("owned_by") == "ollama"
+                    and custom_model.id
+                    == model["id"].split(":")[
+                        0
+                    ]  # Ollama may return model ids in different formats (e.g., 'llama3' vs. 'llama3:7b')
+                ):
+                    if custom_model.is_active:
+                        model["name"] = custom_model.name
+                        model["info"] = custom_model.model_dump()
+
+                        # Set action_ids and filter_ids
+                        action_ids = []
+                        filter_ids = []
+
+                        if "info" in model and "meta" in model["info"]:
+                            action_ids.extend(
+                                model["info"]["meta"].get("actionIds", [])
+                            )
+                            filter_ids.extend(
+                                model["info"]["meta"].get("filterIds", [])
+                            )
+
+                        model["action_ids"] = action_ids
+                        model["filter_ids"] = filter_ids
+                    else:
+                        models.remove(model)
+
+        elif custom_model.is_active and (
+            custom_model.id not in [model["id"] for model in models]
+        ):
+            owned_by = "openai"
+            pipe = None
+
+            action_ids = []
+            filter_ids = []
+
+            for model in models:
+                if (
+                    custom_model.base_model_id == model["id"]
+                    or custom_model.base_model_id == model["id"].split(":")[0]
+                ):
+                    owned_by = model.get("owned_by", "unknown owner")
+                    if "pipe" in model:
+                        pipe = model["pipe"]
+                    break
+
+            if custom_model.meta:
+                meta = custom_model.meta.model_dump()
+
+                if "actionIds" in meta:
+                    action_ids.extend(meta["actionIds"])
 
-        # Process action_ids to get the actions
-        def get_action_items_from_module(function, module):
-            actions = []
-            if hasattr(module, "actions"):
-                actions = module.actions
-                return [
-                    {
-                        "id": f"{function.id}.{action['id']}",
-                        "name": action.get("name", f"{function.name} ({action['id']})"),
-                        "description": function.meta.description,
-                        "icon": action.get(
-                            "icon_url",
-                            function.meta.manifest.get("icon_url", None)
-                            or getattr(module, "icon_url", None)
-                            or getattr(module, "icon", None),
-                        ),
-                    }
-                    for action in actions
-                ]
-            else:
-                return [
-                    {
-                        "id": function.id,
-                        "name": function.name,
-                        "description": function.meta.description,
-                        "icon": function.meta.manifest.get("icon_url", None)
+                if "filterIds" in meta:
+                    filter_ids.extend(meta["filterIds"])
+
+            models.append(
+                {
+                    "id": f"{custom_model.id}",
+                    "name": custom_model.name,
+                    "object": "model",
+                    "created": custom_model.created_at,
+                    "owned_by": owned_by,
+                    "info": custom_model.model_dump(),
+                    "preset": True,
+                    **({"pipe": pipe} if pipe is not None else {}),
+                    "action_ids": action_ids,
+                    "filter_ids": filter_ids,
+                }
+            )
+
+    # Process action_ids to get the actions
+    def get_action_items_from_module(function, module):
+        actions = []
+        if hasattr(module, "actions"):
+            actions = module.actions
+            return [
+                {
+                    "id": f"{function.id}.{action['id']}",
+                    "name": action.get("name", f"{function.name} ({action['id']})"),
+                    "description": function.meta.description,
+                    "icon": action.get(
+                        "icon_url",
+                        function.meta.manifest.get("icon_url", None)
                         or getattr(module, "icon_url", None)
                         or getattr(module, "icon", None),
-                    }
-                ]
-
-        # Process filter_ids to get the filters
-        def get_filter_items_from_module(function, module):
+                    ),
+                }
+                for action in actions
+            ]
+        else:
             return [
                 {
                     "id": function.id,
@@ -258,54 +248,63 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
                 }
             ]
 
-        def get_function_module_by_id(function_id):
-            function_module, _, _ = get_function_module_from_cache(request, function_id)
-            return function_module
+    # Process filter_ids to get the filters
+    def get_filter_items_from_module(function, module):
+        return [
+            {
+                "id": function.id,
+                "name": function.name,
+                "description": function.meta.description,
+                "icon": function.meta.manifest.get("icon_url", None)
+                or getattr(module, "icon_url", None)
+                or getattr(module, "icon", None),
+            }
+        ]
 
-        for model in models:
-            action_ids = [
-                action_id
-                for action_id in list(
-                    set(model.pop("action_ids", []) + global_action_ids)
-                )
-                if action_id in enabled_action_ids
-            ]
-            filter_ids = [
-                filter_id
-                for filter_id in list(
-                    set(model.pop("filter_ids", []) + global_filter_ids)
-                )
-                if filter_id in enabled_filter_ids
-            ]
+    def get_function_module_by_id(function_id):
+        function_module, _, _ = get_function_module_from_cache(request, function_id)
+        return function_module
 
-            model["actions"] = []
-            for action_id in action_ids:
-                action_function = Functions.get_function_by_id(action_id)
-                if action_function is None:
-                    raise Exception(f"Action not found: {action_id}")
+    for model in models:
+        action_ids = [
+            action_id
+            for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
+            if action_id in enabled_action_ids
+        ]
+        filter_ids = [
+            filter_id
+            for filter_id in list(set(model.pop("filter_ids", []) + global_filter_ids))
+            if filter_id in enabled_filter_ids
+        ]
 
-                function_module = get_function_module_by_id(action_id)
-                model["actions"].extend(
-                    get_action_items_from_module(action_function, function_module)
-                )
+        model["actions"] = []
+        for action_id in action_ids:
+            action_function = Functions.get_function_by_id(action_id)
+            if action_function is None:
+                raise Exception(f"Action not found: {action_id}")
 
-            model["filters"] = []
-            for filter_id in filter_ids:
-                filter_function = Functions.get_function_by_id(filter_id)
-                if filter_function is None:
-                    raise Exception(f"Filter not found: {filter_id}")
+            function_module = get_function_module_by_id(action_id)
+            model["actions"].extend(
+                get_action_items_from_module(action_function, function_module)
+            )
 
-                function_module = get_function_module_by_id(filter_id)
+        model["filters"] = []
+        for filter_id in filter_ids:
+            filter_function = Functions.get_function_by_id(filter_id)
+            if filter_function is None:
+                raise Exception(f"Filter not found: {filter_id}")
 
-                if getattr(function_module, "toggle", None):
-                    model["filters"].extend(
-                        get_filter_items_from_module(filter_function, function_module)
-                    )
+            function_module = get_function_module_by_id(filter_id)
+
+            if getattr(function_module, "toggle", None):
+                model["filters"].extend(
+                    get_filter_items_from_module(filter_function, function_module)
+                )
 
-        log.debug(f"get_all_models() returned {len(models)} models")
+    log.debug(f"get_all_models() returned {len(models)} models")
 
-        request.app.state.MODELS = {model["id"]: model for model in models}
-        return models
+    request.app.state.MODELS = {model["id"]: model for model in models}
+    return models
 
 
 def check_model_access(user, model):

+ 3 - 3
src/lib/components/admin/Settings/Connections.svelte

@@ -386,12 +386,12 @@
 
 				<div class="my-2">
 					<div class="flex justify-between items-center text-sm">
-						<div class=" text-xs font-medium">{$i18n.t('Cache Model List')}</div>
+						<div class=" text-xs font-medium">{$i18n.t('Cache Base Model List')}</div>
 
 						<div class="flex items-center">
 							<div class="">
 								<Switch
-									bind:state={connectionsConfig.ENABLE_MODEL_LIST_CACHE}
+									bind:state={connectionsConfig.ENABLE_BASE_MODELS_CACHE}
 									on:change={async () => {
 										updateConnectionsHandler();
 									}}
@@ -402,7 +402,7 @@
 
 					<div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
 						{$i18n.t(
-							'Model List Cache speeds up access by fetching models only at startup or on settings save—faster, but may not show recent model changes.'
+							'Base Model List Cache speeds up access by fetching base models only at startup or on settings save—faster, but may not show recent base model changes.'
 						)}
 					</div>
 				</div>