Browse Source

refac: group members backend

Timothy Jaeryang Baek 3 months ago
parent
commit
bc576782d7

+ 158 - 66
backend/open_webui/models/groups.py

@@ -11,7 +11,7 @@ from open_webui.models.files import FileMetadataResponse
 
 
 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Column, String, Text, JSON, func
+from sqlalchemy import BigInteger, Column, String, Text, JSON, func, ForeignKey
 
 
 log = logging.getLogger(__name__)
@@ -35,7 +35,6 @@ class Group(Base):
     meta = Column(JSON, nullable=True)
 
     permissions = Column(JSON, nullable=True)
-    user_ids = Column(JSON, nullable=True)
 
     created_at = Column(BigInteger)
     updated_at = Column(BigInteger)
@@ -53,12 +52,33 @@ class GroupModel(BaseModel):
     meta: Optional[dict] = None
 
     permissions: Optional[dict] = None
-    user_ids: list[str] = []
 
     created_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
 
 
+class GroupMember(Base):
+    __tablename__ = "group_member"
+
+    id = Column(Text, unique=True, primary_key=True)
+    group_id = Column(
+        Text,
+        ForeignKey("group.id", ondelete="CASCADE"),
+        nullable=False,
+    )
+    user_id = Column(Text, nullable=False)
+    created_at = Column(BigInteger, nullable=True)
+    updated_at = Column(BigInteger, nullable=True)
+
+
+class GroupMemberModel(BaseModel):
+    id: str
+    group_id: str
+    user_id: str
+    created_at: Optional[int] = None  # timestamp in epoch
+    updated_at: Optional[int] = None  # timestamp in epoch
+
+
 ####################
 # Forms
 ####################
@@ -72,7 +92,7 @@ class GroupResponse(BaseModel):
     permissions: Optional[dict] = None
     data: Optional[dict] = None
     meta: Optional[dict] = None
-    user_ids: list[str] = []
+    member_count: Optional[int] = None
     created_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
 
@@ -87,7 +107,7 @@ class UserIdsForm(BaseModel):
     user_ids: Optional[list[str]] = None
 
 
-class GroupUpdateForm(GroupForm, UserIdsForm):
+class GroupUpdateForm(GroupForm):
     pass
 
 
@@ -131,12 +151,8 @@ class GroupTable:
             return [
                 GroupModel.model_validate(group)
                 for group in db.query(Group)
-                .filter(
-                    func.json_array_length(Group.user_ids) > 0
-                )  # Ensure array exists
-                .filter(
-                    Group.user_ids.cast(String).like(f'%"{user_id}"%')
-                )  # String-based check
+                .join(GroupMember, GroupMember.group_id == Group.id)
+                .filter(GroupMember.user_id == user_id)
                 .order_by(Group.updated_at.desc())
                 .all()
             ]
@@ -149,12 +165,46 @@ class GroupTable:
         except Exception:
             return None
 
-    def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
-        group = self.get_group_by_id(id)
-        if group:
-            return group.user_ids
-        else:
-            return None
+    def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
+        with get_db() as db:
+            members = (
+                db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
+            )
+
+            if not members:
+                return None
+
+            return [m[0] for m in members]
+
+    def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None:
+        with get_db() as db:
+            # Delete existing members
+            db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
+
+            # Insert new members
+            now = int(time.time())
+            new_members = [
+                GroupMember(
+                    id=str(uuid.uuid4()),
+                    group_id=group_id,
+                    user_id=user_id,
+                    created_at=now,
+                    updated_at=now,
+                )
+                for user_id in user_ids
+            ]
+
+            db.add_all(new_members)
+            db.commit()
+
+    def get_group_member_count_by_id(self, id: str) -> int:
+        with get_db() as db:
+            count = (
+                db.query(func.count(GroupMember.user_id))
+                .filter(GroupMember.group_id == id)
+                .scalar()
+            )
+            return count if count else 0
 
     def update_group_by_id(
         self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
@@ -195,20 +245,29 @@ class GroupTable:
     def remove_user_from_all_groups(self, user_id: str) -> bool:
         with get_db() as db:
             try:
-                groups = self.get_groups_by_member_id(user_id)
+                # Find all groups the user belongs to
+                groups = (
+                    db.query(Group)
+                    .join(GroupMember, GroupMember.group_id == Group.id)
+                    .filter(GroupMember.user_id == user_id)
+                    .all()
+                )
 
+                # Remove the user from each group
                 for group in groups:
-                    group.user_ids.remove(user_id)
+                    db.query(GroupMember).filter(
+                        GroupMember.group_id == group.id, GroupMember.user_id == user_id
+                    ).delete()
+
                     db.query(Group).filter_by(id=group.id).update(
-                        {
-                            "user_ids": group.user_ids,
-                            "updated_at": int(time.time()),
-                        }
+                        {"updated_at": int(time.time())}
                     )
-                    db.commit()
 
+                db.commit()
                 return True
+
             except Exception:
+                db.rollback()
                 return False
 
     def create_groups_by_group_names(
@@ -246,37 +305,61 @@ class GroupTable:
     def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
         with get_db() as db:
             try:
-                groups = db.query(Group).filter(Group.name.in_(group_names)).all()
-                group_ids = [group.id for group in groups]
-
-                # Remove user from groups not in the new list
-                existing_groups = self.get_groups_by_member_id(user_id)
-
-                for group in existing_groups:
-                    if group.id not in group_ids:
-                        group.user_ids.remove(user_id)
-                        db.query(Group).filter_by(id=group.id).update(
-                            {
-                                "user_ids": group.user_ids,
-                                "updated_at": int(time.time()),
-                            }
-                        )
+                now = int(time.time())
 
-                # Add user to new groups
-                for group in groups:
-                    if user_id not in group.user_ids:
-                        group.user_ids.append(user_id)
-                        db.query(Group).filter_by(id=group.id).update(
-                            {
-                                "user_ids": group.user_ids,
-                                "updated_at": int(time.time()),
-                            }
+                # 1. Groups that SHOULD contain the user
+                target_groups = (
+                    db.query(Group).filter(Group.name.in_(group_names)).all()
+                )
+                target_group_ids = {g.id for g in target_groups}
+
+                # 2. Groups the user is CURRENTLY in
+                existing_group_ids = {
+                    g.id
+                    for g in db.query(Group)
+                    .join(GroupMember, GroupMember.group_id == Group.id)
+                    .filter(GroupMember.user_id == user_id)
+                    .all()
+                }
+
+                # 3. Determine adds + removals
+                groups_to_add = target_group_ids - existing_group_ids
+                groups_to_remove = existing_group_ids - target_group_ids
+
+                # 4. Remove in one bulk delete
+                if groups_to_remove:
+                    db.query(GroupMember).filter(
+                        GroupMember.user_id == user_id,
+                        GroupMember.group_id.in_(groups_to_remove),
+                    ).delete(synchronize_session=False)
+
+                    db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
+                        {"updated_at": now}, synchronize_session=False
+                    )
+
+                # 5. Bulk insert missing memberships
+                for group_id in groups_to_add:
+                    db.add(
+                        GroupMember(
+                            id=str(uuid.uuid4()),
+                            group_id=group_id,
+                            user_id=user_id,
+                            created_at=now,
+                            updated_at=now,
                         )
+                    )
+
+                if groups_to_add:
+                    db.query(Group).filter(Group.id.in_(groups_to_add)).update(
+                        {"updated_at": now}, synchronize_session=False
+                    )
 
                 db.commit()
                 return True
+
             except Exception as e:
                 log.exception(e)
+                db.rollback()
                 return False
 
     def add_users_to_group(
@@ -288,21 +371,31 @@ class GroupTable:
                 if not group:
                     return None
 
-                group_user_ids = group.user_ids
-                if not group_user_ids or not isinstance(group_user_ids, list):
-                    group_user_ids = []
-
-                group_user_ids = list(set(group_user_ids))  # Deduplicate
+                now = int(time.time())
 
-                for user_id in user_ids:
-                    if user_id not in group_user_ids:
-                        group_user_ids.append(user_id)
+                for user_id in user_ids or []:
+                    try:
+                        db.add(
+                            GroupMember(
+                                id=str(uuid.uuid4()),
+                                group_id=id,
+                                user_id=user_id,
+                                created_at=now,
+                                updated_at=now,
+                            )
+                        )
+                        db.flush()  # Detect unique constraint violation early
+                    except Exception:
+                        db.rollback()  # Clear failed INSERT
+                        db.begin()  # Start a new transaction
+                        continue  # Duplicate → ignore
 
-                group.user_ids = group_user_ids
-                group.updated_at = int(time.time())
+                group.updated_at = now
                 db.commit()
                 db.refresh(group)
+
                 return GroupModel.model_validate(group)
+
         except Exception as e:
             log.exception(e)
             return None
@@ -316,23 +409,22 @@ class GroupTable:
                 if not group:
                     return None
 
-                group_user_ids = group.user_ids
-
-                if not group_user_ids or not isinstance(group_user_ids, list):
+                if not user_ids:
                     return GroupModel.model_validate(group)
 
-                group_user_ids = list(set(group_user_ids))  # Deduplicate
-
+                # Remove each user from group_member
                 for user_id in user_ids:
-                    if user_id in group_user_ids:
-                        group_user_ids.remove(user_id)
+                    db.query(GroupMember).filter(
+                        GroupMember.group_id == id, GroupMember.user_id == user_id
+                    ).delete()
 
-                group.user_ids = group_user_ids
+                # Update group timestamp
                 group.updated_at = int(time.time())
 
                 db.commit()
                 db.refresh(group)
                 return GroupModel.model_validate(group)
+
         except Exception as e:
             log.exception(e)
             return None

+ 31 - 10
backend/open_webui/routers/groups.py

@@ -33,9 +33,18 @@ router = APIRouter()
 @router.get("/", response_model=list[GroupResponse])
 async def get_groups(user=Depends(get_verified_user)):
     if user.role == "admin":
-        return Groups.get_groups()
+        groups = Groups.get_groups()
     else:
-        return Groups.get_groups_by_member_id(user.id)
+        groups = Groups.get_groups_by_member_id(user.id)
+
+    return [
+        GroupResponse(
+            **group.model_dump(),
+            member_count=Groups.get_group_member_count_by_id(group.id),
+        )
+        for group in groups
+        if group
+    ]
 
 
 ############################
@@ -48,7 +57,10 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
     try:
         group = Groups.insert_new_group(user.id, form_data)
         if group:
-            return group
+            return GroupResponse(
+                **group.model_dump(),
+                member_count=Groups.get_group_member_count_by_id(group.id),
+            )
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
@@ -71,7 +83,10 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
 async def get_group_by_id(id: str, user=Depends(get_admin_user)):
     group = Groups.get_group_by_id(id)
     if group:
-        return group
+        return GroupResponse(
+            **group.model_dump(),
+            member_count=Groups.get_group_member_count_by_id(group.id),
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -89,12 +104,12 @@ async def update_group_by_id(
     id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
 ):
     try:
-        if form_data.user_ids:
-            form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
-
         group = Groups.update_group_by_id(id, form_data)
         if group:
-            return group
+            return GroupResponse(
+                **group.model_dump(),
+                member_count=Groups.get_group_member_count_by_id(group.id),
+            )
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
@@ -123,7 +138,10 @@ async def add_user_to_group(
 
         group = Groups.add_users_to_group(id, form_data.user_ids)
         if group:
-            return group
+            return GroupResponse(
+                **group.model_dump(),
+                member_count=Groups.get_group_member_count_by_id(group.id),
+            )
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
@@ -144,7 +162,10 @@ async def remove_users_from_group(
     try:
         group = Groups.remove_users_from_group(id, form_data.user_ids)
         if group:
-            return group
+            return GroupResponse(
+                **group.model_dump(),
+                member_count=Groups.get_group_member_count_by_id(group.id),
+            )
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,

+ 13 - 9
backend/open_webui/routers/scim.py

@@ -349,8 +349,10 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
 
 def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
     """Convert internal Group model to SCIM Group"""
+    member_ids = Groups.get_group_user_ids_by_id(group.id)
     members = []
-    for user_id in group.user_ids:
+
+    for user_id in member_ids:
         user = Users.get_user_by_id(user_id)
         if user:
             members.append(
@@ -796,9 +798,11 @@ async def create_group(
         update_form = GroupUpdateForm(
             name=new_group.name,
             description=new_group.description,
-            user_ids=member_ids,
         )
+
         Groups.update_group_by_id(new_group.id, update_form)
+        Groups.set_group_user_ids_by_id(new_group.id, member_ids)
+
         new_group = Groups.get_group_by_id(new_group.id)
 
     return group_to_scim(new_group, request)
@@ -830,7 +834,7 @@ async def update_group(
     # Handle members if provided
     if group_data.members is not None:
         member_ids = [member.value for member in group_data.members]
-        update_form.user_ids = member_ids
+        Groups.set_group_user_ids_by_id(group_id, member_ids)
 
     # Update group
     updated_group = Groups.update_group_by_id(group_id, update_form)
@@ -863,7 +867,6 @@ async def patch_group(
     update_form = GroupUpdateForm(
         name=group.name,
         description=group.description,
-        user_ids=group.user_ids.copy() if group.user_ids else [],
     )
 
     for operation in patch_data.Operations:
@@ -876,21 +879,22 @@ async def patch_group(
                 update_form.name = value
             elif path == "members":
                 # Replace all members
-                update_form.user_ids = [member["value"] for member in value]
+                Groups.set_group_user_ids_by_id(
+                    group_id, [member["value"] for member in value]
+                )
+
         elif op == "add":
             if path == "members":
                 # Add members
                 if isinstance(value, list):
                     for member in value:
                         if isinstance(member, dict) and "value" in member:
-                            if member["value"] not in update_form.user_ids:
-                                update_form.user_ids.append(member["value"])
+                            Groups.add_users_to_group(group_id, [member["value"]])
         elif op == "remove":
             if path and path.startswith("members[value eq"):
                 # Remove specific member
                 member_id = path.split('"')[1]
-                if member_id in update_form.user_ids:
-                    update_form.user_ids.remove(member_id)
+                Groups.remove_users_from_group(group_id, [member_id])
 
     # Update group
     updated_group = Groups.update_group_by_id(group_id, update_form)

+ 16 - 18
backend/open_webui/utils/oauth.py

@@ -1130,22 +1130,21 @@ class OAuthManager:
                     f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
                 )
 
-                user_ids = group_model.user_ids
-                user_ids = [i for i in user_ids if i != user.id]
+                Groups.remove_users_from_group(group_model.id, [user.id])
 
                 # In case a group is created, but perms are never assigned to the group by hitting "save"
                 group_permissions = group_model.permissions
                 if not group_permissions:
                     group_permissions = default_permissions
 
-                update_form = GroupUpdateForm(
-                    name=group_model.name,
-                    description=group_model.description,
-                    permissions=group_permissions,
-                    user_ids=user_ids,
-                )
                 Groups.update_group_by_id(
-                    id=group_model.id, form_data=update_form, overwrite=False
+                    id=group_model.id,
+                    form_data=GroupUpdateForm(
+                        name=group_model.name,
+                        description=group_model.description,
+                        permissions=group_permissions,
+                    ),
+                    overwrite=False,
                 )
 
         # Add user to new groups
@@ -1161,22 +1160,21 @@ class OAuthManager:
                     f"Adding user to group {group_model.name} as it was found in their oauth groups"
                 )
 
-                user_ids = group_model.user_ids
-                user_ids.append(user.id)
+                Groups.add_users_to_group(group_model.id, [user.id])
 
                 # In case a group is created, but perms are never assigned to the group by hitting "save"
                 group_permissions = group_model.permissions
                 if not group_permissions:
                     group_permissions = default_permissions
 
-                update_form = GroupUpdateForm(
-                    name=group_model.name,
-                    description=group_model.description,
-                    permissions=group_permissions,
-                    user_ids=user_ids,
-                )
                 Groups.update_group_by_id(
-                    id=group_model.id, form_data=update_form, overwrite=False
+                    id=group_model.id,
+                    form_data=GroupUpdateForm(
+                        name=group_model.name,
+                        description=group_model.description,
+                        permissions=group_permissions,
+                    ),
+                    overwrite=False,
                 )
 
     async def _process_picture_url(