Timothy Jaeryang Baek 3 недель назад
Родитель
Сommit
9738ddfd99
2 измененных файлов с 41 добавлено и 34 удалено
  1. 2 33
      backend/open_webui/main.py
  2. 39 1
      backend/open_webui/utils/models.py

+ 2 - 33
backend/open_webui/main.py

@@ -448,6 +448,7 @@ from open_webui.utils.models import (
     get_all_models,
     get_all_base_models,
     check_model_access,
+    get_filtered_models,
 )
 from open_webui.utils.chat import (
     generate_chat_completion as chat_completion_handler,
@@ -1291,33 +1292,6 @@ if audit_level != AuditLevel.NONE:
 async def get_models(
     request: Request, refresh: bool = False, user=Depends(get_verified_user)
 ):
-    def get_filtered_models(models, user):
-        filtered_models = []
-        for model in models:
-            if model.get("arena"):
-                if has_access(
-                    user.id,
-                    type="read",
-                    access_control=model.get("info", {})
-                    .get("meta", {})
-                    .get("access_control", {}),
-                ):
-                    filtered_models.append(model)
-                continue
-
-            model_info = Models.get_model_by_id(model["id"])
-            if model_info:
-                if (
-                    (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
-                    or user.id == model_info.user_id
-                    or has_access(
-                        user.id, type="read", access_control=model_info.access_control
-                    )
-                ):
-                    filtered_models.append(model)
-
-        return filtered_models
-
     all_models = await get_all_models(request, refresh=refresh, user=user)
 
     models = []
@@ -1353,12 +1327,7 @@ async def get_models(
             )
         )
 
-    # Filter out models that the user does not have access to
-    if (
-        user.role == "user"
-        or (user.role == "admin" and not BYPASS_ADMIN_ACCESS_CONTROL)
-    ) and not BYPASS_MODEL_ACCESS_CONTROL:
-        models = get_filtered_models(models, user)
+    models = get_filtered_models(models, user)
 
     log.debug(
         f"/api/models returned filtered models accessible to the user: {json.dumps([model.get('id') for model in models])}"

+ 39 - 1
backend/open_webui/utils/models.py

@@ -22,10 +22,11 @@ from open_webui.utils.access_control import has_access
 
 
 from open_webui.config import (
+    BYPASS_ADMIN_ACCESS_CONTROL,
     DEFAULT_ARENA_MODEL,
 )
 
-from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
+from open_webui.env import BYPASS_MODEL_ACCESS_CONTROL, SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
 from open_webui.models.users import UserModel
 
 
@@ -332,3 +333,40 @@ def check_model_access(user, model):
             )
         ):
             raise Exception("Model not found")
+
+
+def get_filtered_models(models, user):
+    # Filter out models that the user does not have access to
+    if (
+        user.role == "user"
+        or (user.role == "admin" and not BYPASS_ADMIN_ACCESS_CONTROL)
+    ) and not BYPASS_MODEL_ACCESS_CONTROL:
+        filtered_models = []
+        for model in models:
+            if model.get("arena"):
+                if has_access(
+                    user.id,
+                    type="read",
+                    access_control=model.get("info", {})
+                    .get("meta", {})
+                    .get("access_control", {}),
+                ):
+                    filtered_models.append(model)
+                continue
+
+            model_info = Models.get_model_by_id(model["id"])
+            if model_info:
+                if (
+                    (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
+                    or user.id == model_info.user_id
+                    or has_access(
+                        user.id,
+                        type="read",
+                        access_control=model_info.access_control,
+                    )
+                ):
+                    filtered_models.append(model)
+
+        return filtered_models
+    else:
+        return models