Ver Fonte

Merge pull request #17166 from sihyeonn/perf/sh-model-layer

perf: fix N+1 query issues in user group access control validation
Tim Jaeryang Baek há 1 mês atrás
pai
commit
472b71f331

+ 3 - 1
backend/open_webui/models/knowledge.py

@@ -8,6 +8,7 @@ from open_webui.internal.db import Base, get_db
 from open_webui.env import SRC_LOG_LEVELS
 
 from open_webui.models.files import FileMetadataResponse
+from open_webui.models.groups import Groups
 from open_webui.models.users import Users, UserResponse
 
 
@@ -152,11 +153,12 @@ class KnowledgeTable:
         self, user_id: str, permission: str = "write"
     ) -> list[KnowledgeUserModel]:
         knowledge_bases = self.get_knowledge_bases()
+        user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
         return [
             knowledge_base
             for knowledge_base in knowledge_bases
             if knowledge_base.user_id == user_id
-            or has_access(user_id, permission, knowledge_base.access_control)
+            or has_access(user_id, permission, knowledge_base.access_control, user_group_ids)
         ]
 
     def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:

+ 3 - 1
backend/open_webui/models/models.py

@@ -5,6 +5,7 @@ from typing import Optional
 from open_webui.internal.db import Base, JSONField, get_db
 from open_webui.env import SRC_LOG_LEVELS
 
+from open_webui.models.groups import Groups
 from open_webui.models.users import Users, UserResponse
 
 
@@ -206,11 +207,12 @@ class ModelsTable:
         self, user_id: str, permission: str = "write"
     ) -> list[ModelUserResponse]:
         models = self.get_models()
+        user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
         return [
             model
             for model in models
             if model.user_id == user_id
-            or has_access(user_id, permission, model.access_control)
+            or has_access(user_id, permission, model.access_control, user_group_ids)
         ]
 
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:

+ 3 - 1
backend/open_webui/models/notes.py

@@ -4,6 +4,7 @@ import uuid
 from typing import Optional
 
 from open_webui.internal.db import Base, get_db
+from open_webui.models.groups import Groups
 from open_webui.utils.access_control import has_access
 from open_webui.models.users import Users, UserResponse
 
@@ -105,11 +106,12 @@ class NoteTable:
         self, user_id: str, permission: str = "write"
     ) -> list[NoteModel]:
         notes = self.get_notes()
+        user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
         return [
             note
             for note in notes
             if note.user_id == user_id
-            or has_access(user_id, permission, note.access_control)
+            or has_access(user_id, permission, note.access_control, user_group_ids)
         ]
 
     def get_note_by_id(self, id: str) -> Optional[NoteModel]:

+ 3 - 1
backend/open_webui/models/prompts.py

@@ -2,6 +2,7 @@ import time
 from typing import Optional
 
 from open_webui.internal.db import Base, get_db
+from open_webui.models.groups import Groups
 from open_webui.models.users import Users, UserResponse
 
 from pydantic import BaseModel, ConfigDict
@@ -128,12 +129,13 @@ class PromptsTable:
         self, user_id: str, permission: str = "write"
     ) -> list[PromptUserResponse]:
         prompts = self.get_prompts()
+        user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
 
         return [
             prompt
             for prompt in prompts
             if prompt.user_id == user_id
-            or has_access(user_id, permission, prompt.access_control)
+            or has_access(user_id, permission, prompt.access_control, user_group_ids)
         ]
 
     def update_prompt_by_command(

+ 2 - 1
backend/open_webui/models/tools.py

@@ -168,12 +168,13 @@ class ToolsTable:
         self, user_id: str, permission: str = "write"
     ) -> list[ToolUserModel]:
         tools = self.get_tools()
+        user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
 
         return [
             tool
             for tool in tools
             if tool.user_id == user_id
-            or has_access(user_id, permission, tool.access_control)
+            or has_access(user_id, permission, tool.access_control, user_group_ids)
         ]
 
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]: