|
|
@@ -11,7 +11,7 @@ from open_webui.models.files import FileMetadataResponse
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, ConfigDict
|
|
|
-from sqlalchemy import BigInteger, Column, String, Text, JSON, func
|
|
|
+from sqlalchemy import BigInteger, Column, String, Text, JSON, func, ForeignKey
|
|
|
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
@@ -35,7 +35,6 @@ class Group(Base):
|
|
|
meta = Column(JSON, nullable=True)
|
|
|
|
|
|
permissions = Column(JSON, nullable=True)
|
|
|
- user_ids = Column(JSON, nullable=True)
|
|
|
|
|
|
created_at = Column(BigInteger)
|
|
|
updated_at = Column(BigInteger)
|
|
|
@@ -53,12 +52,33 @@ class GroupModel(BaseModel):
|
|
|
meta: Optional[dict] = None
|
|
|
|
|
|
permissions: Optional[dict] = None
|
|
|
- user_ids: list[str] = []
|
|
|
|
|
|
created_at: int # timestamp in epoch
|
|
|
updated_at: int # timestamp in epoch
|
|
|
|
|
|
|
|
|
+class GroupMember(Base):
|
|
|
+ __tablename__ = "group_member"
|
|
|
+
|
|
|
+ id = Column(Text, unique=True, primary_key=True)
|
|
|
+ group_id = Column(
|
|
|
+ Text,
|
|
|
+ ForeignKey("group.id", ondelete="CASCADE"),
|
|
|
+ nullable=False,
|
|
|
+ )
|
|
|
+ user_id = Column(Text, nullable=False)
|
|
|
+ created_at = Column(BigInteger, nullable=True)
|
|
|
+ updated_at = Column(BigInteger, nullable=True)
|
|
|
+
|
|
|
+
|
|
|
+class GroupMemberModel(BaseModel):
|
|
|
+ id: str
|
|
|
+ group_id: str
|
|
|
+ user_id: str
|
|
|
+ created_at: Optional[int] = None # timestamp in epoch
|
|
|
+ updated_at: Optional[int] = None # timestamp in epoch
|
|
|
+
|
|
|
+
|
|
|
####################
|
|
|
# Forms
|
|
|
####################
|
|
|
@@ -72,7 +92,7 @@ class GroupResponse(BaseModel):
|
|
|
permissions: Optional[dict] = None
|
|
|
data: Optional[dict] = None
|
|
|
meta: Optional[dict] = None
|
|
|
- user_ids: list[str] = []
|
|
|
+ member_count: Optional[int] = None
|
|
|
created_at: int # timestamp in epoch
|
|
|
updated_at: int # timestamp in epoch
|
|
|
|
|
|
@@ -87,7 +107,7 @@ class UserIdsForm(BaseModel):
|
|
|
user_ids: Optional[list[str]] = None
|
|
|
|
|
|
|
|
|
-class GroupUpdateForm(GroupForm, UserIdsForm):
|
|
|
+class GroupUpdateForm(GroupForm):
|
|
|
pass
|
|
|
|
|
|
|
|
|
@@ -131,12 +151,8 @@ class GroupTable:
|
|
|
return [
|
|
|
GroupModel.model_validate(group)
|
|
|
for group in db.query(Group)
|
|
|
- .filter(
|
|
|
- func.json_array_length(Group.user_ids) > 0
|
|
|
- ) # Ensure array exists
|
|
|
- .filter(
|
|
|
- Group.user_ids.cast(String).like(f'%"{user_id}"%')
|
|
|
- ) # String-based check
|
|
|
+ .join(GroupMember, GroupMember.group_id == Group.id)
|
|
|
+ .filter(GroupMember.user_id == user_id)
|
|
|
.order_by(Group.updated_at.desc())
|
|
|
.all()
|
|
|
]
|
|
|
@@ -149,12 +165,46 @@ class GroupTable:
|
|
|
except Exception:
|
|
|
return None
|
|
|
|
|
|
- def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
|
|
|
- group = self.get_group_by_id(id)
|
|
|
- if group:
|
|
|
- return group.user_ids
|
|
|
- else:
|
|
|
- return None
|
|
|
+ def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
|
|
|
+ with get_db() as db:
|
|
|
+ members = (
|
|
|
+ db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
|
|
|
+ )
|
|
|
+
|
|
|
+ if not members:
|
|
|
+ return None
|
|
|
+
|
|
|
+ return [m[0] for m in members]
|
|
|
+
|
|
|
+ def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None:
|
|
|
+ with get_db() as db:
|
|
|
+ # Delete existing members
|
|
|
+ db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
|
|
|
+
|
|
|
+ # Insert new members
|
|
|
+ now = int(time.time())
|
|
|
+ new_members = [
|
|
|
+ GroupMember(
|
|
|
+ id=str(uuid.uuid4()),
|
|
|
+ group_id=group_id,
|
|
|
+ user_id=user_id,
|
|
|
+ created_at=now,
|
|
|
+ updated_at=now,
|
|
|
+ )
|
|
|
+ for user_id in user_ids
|
|
|
+ ]
|
|
|
+
|
|
|
+ db.add_all(new_members)
|
|
|
+ db.commit()
|
|
|
+
|
|
|
+ def get_group_member_count_by_id(self, id: str) -> int:
|
|
|
+ with get_db() as db:
|
|
|
+ count = (
|
|
|
+ db.query(func.count(GroupMember.user_id))
|
|
|
+ .filter(GroupMember.group_id == id)
|
|
|
+ .scalar()
|
|
|
+ )
|
|
|
+ return count if count else 0
|
|
|
|
|
|
def update_group_by_id(
|
|
|
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
|
|
|
@@ -195,20 +245,29 @@ class GroupTable:
|
|
|
def remove_user_from_all_groups(self, user_id: str) -> bool:
|
|
|
with get_db() as db:
|
|
|
try:
|
|
|
- groups = self.get_groups_by_member_id(user_id)
|
|
|
+ # Find all groups the user belongs to
|
|
|
+ groups = (
|
|
|
+ db.query(Group)
|
|
|
+ .join(GroupMember, GroupMember.group_id == Group.id)
|
|
|
+ .filter(GroupMember.user_id == user_id)
|
|
|
+ .all()
|
|
|
+ )
|
|
|
|
|
|
+ # Remove the user from each group
|
|
|
for group in groups:
|
|
|
- group.user_ids.remove(user_id)
|
|
|
+ db.query(GroupMember).filter(
|
|
|
+ GroupMember.group_id == group.id, GroupMember.user_id == user_id
|
|
|
+ ).delete()
|
|
|
+
|
|
|
db.query(Group).filter_by(id=group.id).update(
|
|
|
- {
|
|
|
- "user_ids": group.user_ids,
|
|
|
- "updated_at": int(time.time()),
|
|
|
- }
|
|
|
+ {"updated_at": int(time.time())}
|
|
|
)
|
|
|
- db.commit()
|
|
|
|
|
|
+ db.commit()
|
|
|
return True
|
|
|
+
|
|
|
except Exception:
|
|
|
+ db.rollback()
|
|
|
return False
|
|
|
|
|
|
def create_groups_by_group_names(
|
|
|
@@ -246,37 +305,61 @@ class GroupTable:
|
|
|
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
|
|
|
with get_db() as db:
|
|
|
try:
|
|
|
- groups = db.query(Group).filter(Group.name.in_(group_names)).all()
|
|
|
- group_ids = [group.id for group in groups]
|
|
|
-
|
|
|
- # Remove user from groups not in the new list
|
|
|
- existing_groups = self.get_groups_by_member_id(user_id)
|
|
|
-
|
|
|
- for group in existing_groups:
|
|
|
- if group.id not in group_ids:
|
|
|
- group.user_ids.remove(user_id)
|
|
|
- db.query(Group).filter_by(id=group.id).update(
|
|
|
- {
|
|
|
- "user_ids": group.user_ids,
|
|
|
- "updated_at": int(time.time()),
|
|
|
- }
|
|
|
- )
|
|
|
+ now = int(time.time())
|
|
|
|
|
|
- # Add user to new groups
|
|
|
- for group in groups:
|
|
|
- if user_id not in group.user_ids:
|
|
|
- group.user_ids.append(user_id)
|
|
|
- db.query(Group).filter_by(id=group.id).update(
|
|
|
- {
|
|
|
- "user_ids": group.user_ids,
|
|
|
- "updated_at": int(time.time()),
|
|
|
- }
|
|
|
+ # 1. Groups that SHOULD contain the user
|
|
|
+ target_groups = (
|
|
|
+ db.query(Group).filter(Group.name.in_(group_names)).all()
|
|
|
+ )
|
|
|
+ target_group_ids = {g.id for g in target_groups}
|
|
|
+
|
|
|
+ # 2. Groups the user is CURRENTLY in
|
|
|
+ existing_group_ids = {
|
|
|
+ g.id
|
|
|
+ for g in db.query(Group)
|
|
|
+ .join(GroupMember, GroupMember.group_id == Group.id)
|
|
|
+ .filter(GroupMember.user_id == user_id)
|
|
|
+ .all()
|
|
|
+ }
|
|
|
+
|
|
|
+ # 3. Determine adds + removals
|
|
|
+ groups_to_add = target_group_ids - existing_group_ids
|
|
|
+ groups_to_remove = existing_group_ids - target_group_ids
|
|
|
+
|
|
|
+ # 4. Remove in one bulk delete
|
|
|
+ if groups_to_remove:
|
|
|
+ db.query(GroupMember).filter(
|
|
|
+ GroupMember.user_id == user_id,
|
|
|
+ GroupMember.group_id.in_(groups_to_remove),
|
|
|
+ ).delete(synchronize_session=False)
|
|
|
+
|
|
|
+ db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
|
|
|
+ {"updated_at": now}, synchronize_session=False
|
|
|
+ )
|
|
|
+
|
|
|
+ # 5. Bulk insert missing memberships
|
|
|
+ for group_id in groups_to_add:
|
|
|
+ db.add(
|
|
|
+ GroupMember(
|
|
|
+ id=str(uuid.uuid4()),
|
|
|
+ group_id=group_id,
|
|
|
+ user_id=user_id,
|
|
|
+ created_at=now,
|
|
|
+ updated_at=now,
|
|
|
)
|
|
|
+ )
|
|
|
+
|
|
|
+ if groups_to_add:
|
|
|
+ db.query(Group).filter(Group.id.in_(groups_to_add)).update(
|
|
|
+ {"updated_at": now}, synchronize_session=False
|
|
|
+ )
|
|
|
|
|
|
db.commit()
|
|
|
return True
|
|
|
+
|
|
|
except Exception as e:
|
|
|
log.exception(e)
|
|
|
+ db.rollback()
|
|
|
return False
|
|
|
|
|
|
def add_users_to_group(
|
|
|
@@ -288,21 +371,31 @@ class GroupTable:
|
|
|
if not group:
|
|
|
return None
|
|
|
|
|
|
- group_user_ids = group.user_ids
|
|
|
- if not group_user_ids or not isinstance(group_user_ids, list):
|
|
|
- group_user_ids = []
|
|
|
-
|
|
|
- group_user_ids = list(set(group_user_ids)) # Deduplicate
|
|
|
+ now = int(time.time())
|
|
|
|
|
|
- for user_id in user_ids:
|
|
|
- if user_id not in group_user_ids:
|
|
|
- group_user_ids.append(user_id)
|
|
|
+ for user_id in user_ids or []:
|
|
|
+ try:
|
|
|
+ db.add(
|
|
|
+ GroupMember(
|
|
|
+ id=str(uuid.uuid4()),
|
|
|
+ group_id=id,
|
|
|
+ user_id=user_id,
|
|
|
+ created_at=now,
|
|
|
+ updated_at=now,
|
|
|
+ )
|
|
|
+ )
|
|
|
+ db.flush() # Detect unique constraint violation early
|
|
|
+ except Exception:
|
|
|
+ db.rollback() # Clear failed INSERT
|
|
|
+ db.begin() # Start a new transaction
|
|
|
+ continue # Duplicate → ignore
|
|
|
|
|
|
- group.user_ids = group_user_ids
|
|
|
- group.updated_at = int(time.time())
|
|
|
+ group.updated_at = now
|
|
|
db.commit()
|
|
|
db.refresh(group)
|
|
|
+
|
|
|
return GroupModel.model_validate(group)
|
|
|
+
|
|
|
except Exception as e:
|
|
|
log.exception(e)
|
|
|
return None
|
|
|
@@ -316,23 +409,22 @@ class GroupTable:
|
|
|
if not group:
|
|
|
return None
|
|
|
|
|
|
- group_user_ids = group.user_ids
|
|
|
-
|
|
|
- if not group_user_ids or not isinstance(group_user_ids, list):
|
|
|
+ if not user_ids:
|
|
|
return GroupModel.model_validate(group)
|
|
|
|
|
|
- group_user_ids = list(set(group_user_ids)) # Deduplicate
|
|
|
-
|
|
|
+ # Remove each user from group_member
|
|
|
for user_id in user_ids:
|
|
|
- if user_id in group_user_ids:
|
|
|
- group_user_ids.remove(user_id)
|
|
|
+ db.query(GroupMember).filter(
|
|
|
+ GroupMember.group_id == id, GroupMember.user_id == user_id
|
|
|
+ ).delete()
|
|
|
|
|
|
- group.user_ids = group_user_ids
|
|
|
+ # Update group timestamp
|
|
|
group.updated_at = int(time.time())
|
|
|
|
|
|
db.commit()
|
|
|
db.refresh(group)
|
|
|
return GroupModel.model_validate(group)
|
|
|
+
|
|
|
except Exception as e:
|
|
|
log.exception(e)
|
|
|
return None
|