Browse Source

refac/enh: group add/remove users endpoints

Timothy Jaeryang Baek 2 months ago
parent
commit
9634df4347
2 changed files with 104 additions and 1 deletions
  1. 53 1
      backend/open_webui/models/groups.py
  2. 51 0
      backend/open_webui/routers/groups.py

+ 53 - 1
backend/open_webui/models/groups.py

@@ -83,10 +83,14 @@ class GroupForm(BaseModel):
     permissions: Optional[dict] = None
 
 
-class GroupUpdateForm(GroupForm):
+class UserIdsForm(BaseModel):
     user_ids: Optional[list[str]] = None
 
 
+class GroupUpdateForm(GroupForm, UserIdsForm):
+    pass
+
+
 class GroupTable:
     def insert_new_group(
         self, user_id: str, form_data: GroupForm
@@ -275,5 +279,53 @@ class GroupTable:
                 log.exception(e)
                 return False
 
+    def add_users_to_group(
+        self, id: str, user_ids: Optional[list[str]] = None
+    ) -> Optional[GroupModel]:
+        try:
+            with get_db() as db:
+                group = db.query(Group).filter_by(id=id).first()
+                if not group:
+                    return None
+
+                if not group.user_ids:
+                    group.user_ids = []
+
+                for user_id in user_ids:
+                    if user_id not in group.user_ids:
+                        group.user_ids.append(user_id)
+
+                group.updated_at = int(time.time())
+                db.commit()
+                db.refresh(group)
+                return GroupModel.model_validate(group)
+        except Exception as e:
+            log.exception(e)
+            return None
+
+    def remove_users_from_group(
+        self, id: str, user_ids: Optional[list[str]] = None
+    ) -> Optional[GroupModel]:
+        try:
+            with get_db() as db:
+                group = db.query(Group).filter_by(id=id).first()
+                if not group:
+                    return None
+
+                if not group.user_ids:
+                    return GroupModel.model_validate(group)
+
+                for user_id in user_ids:
+                    if user_id in group.user_ids:
+                        group.user_ids.remove(user_id)
+
+                group.updated_at = int(time.time())
+                db.commit()
+                db.refresh(group)
+                return GroupModel.model_validate(group)
+        except Exception as e:
+            log.exception(e)
+            return None
+
 
 Groups = GroupTable()

+ 51 - 0
backend/open_webui/routers/groups.py

@@ -9,6 +9,7 @@ from open_webui.models.groups import (
     GroupForm,
     GroupUpdateForm,
     GroupResponse,
+    UserIdsForm,
 )
 
 from open_webui.config import CACHE_DIR
@@ -107,6 +108,56 @@ async def update_group_by_id(
         )
 
 
+############################
+# AddUserToGroupByUserIdAndGroupId
+############################
+
+
+@router.post("/id/{id}/users/add", response_model=Optional[GroupResponse])
+async def add_user_to_group(
+    id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
+):
+    try:
+        if form_data.user_ids:
+            form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
+
+        group = Groups.add_users_to_group(id, form_data.user_ids)
+        if group:
+            return group
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error adding users to group"),
+            )
+    except Exception as e:
+        log.exception(f"Error adding users to group {id}: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
+@router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse])
+async def remove_users_from_group(
+    id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
+):
+    try:
+        group = Groups.remove_users_from_group(id, form_data.user_ids)
+        if group:
+            return group
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error removing users from group"),
+            )
+    except Exception as e:
+        log.exception(f"Error removing users from group {id}: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
 ############################
 # DeleteGroupById
 ############################