Timothy Jaeryang Baek преди 3 месеца
родител
ревизия
371bdd7afa
променени са 2 файла, в които са добавени 92 реда и са изтрити 55 реда
  1. 32 2
      backend/open_webui/models/groups.py
  2. 60 53
      backend/open_webui/routers/auths.py

+ 32 - 2
backend/open_webui/models/groups.py

@@ -207,9 +207,39 @@ class GroupTable:
             except Exception:
                 return False
 
-    def sync_user_groups_by_group_names(
+    def create_groups_by_group_names(
         self, user_id: str, group_names: list[str]
-    ) -> bool:
+    ) -> list[GroupModel]:
+
+        # check for existing groups
+        existing_groups = self.get_groups()
+        existing_group_names = {group.name for group in existing_groups}
+
+        new_groups = []
+
+        with get_db() as db:
+            for group_name in group_names:
+                if group_name not in existing_group_names:
+                    new_group = GroupModel(
+                        id=str(uuid.uuid4()),
+                        user_id=user_id,
+                        name=group_name,
+                        description="",
+                        created_at=int(time.time()),
+                        updated_at=int(time.time()),
+                    )
+                    try:
+                        result = Group(**new_group.model_dump())
+                        db.add(result)
+                        db.commit()
+                        db.refresh(result)
+                        new_groups.append(GroupModel.model_validate(result))
+                    except Exception as e:
+                        log.exception(e)
+                        continue
+            return new_groups
+
+    def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
         with get_db() as db:
             try:
                 groups = db.query(Group).filter(Group.name.in_(group_names)).all()

+ 60 - 53
backend/open_webui/routers/auths.py

@@ -228,18 +228,23 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
         if not connection_app.bind():
             raise HTTPException(400, detail="Application account bind failed")
 
-        ENABLE_LDAP_GROUP_MANAGEMENT = request.app.state.config.ENABLE_LDAP_GROUP_MANAGEMENT
+        ENABLE_LDAP_GROUP_MANAGEMENT = (
+            request.app.state.config.ENABLE_LDAP_GROUP_MANAGEMENT
+        )
+        ENABLE_LDAP_GROUP_CREATION = request.app.state.config.ENABLE_LDAP_GROUP_CREATION
         LDAP_ATTRIBUTE_FOR_GROUPS = request.app.state.config.LDAP_ATTRIBUTE_FOR_GROUPS
-        
+
         search_attributes = [
             f"{LDAP_ATTRIBUTE_FOR_USERNAME}",
             f"{LDAP_ATTRIBUTE_FOR_MAIL}",
             "cn",
         ]
-        
+
         if ENABLE_LDAP_GROUP_MANAGEMENT:
             search_attributes.append(f"{LDAP_ATTRIBUTE_FOR_GROUPS}")
-            log.info(f"LDAP Group Management enabled. Adding {LDAP_ATTRIBUTE_FOR_GROUPS} to search attributes")
+            log.info(
+                f"LDAP Group Management enabled. Adding {LDAP_ATTRIBUTE_FOR_GROUPS} to search attributes"
+            )
 
         log.info(f"LDAP search attributes: {search_attributes}")
 
@@ -273,55 +278,64 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
         if ENABLE_LDAP_GROUP_MANAGEMENT and LDAP_ATTRIBUTE_FOR_GROUPS in entry:
             group_dns = entry[LDAP_ATTRIBUTE_FOR_GROUPS]
             log.info(f"LDAP raw group DNs for user {username}: {group_dns}")
-            
+
             if group_dns:
                 log.info(f"LDAP group_dns original: {group_dns}")
                 log.info(f"LDAP group_dns type: {type(group_dns)}")
                 log.info(f"LDAP group_dns length: {len(group_dns)}")
-                
-                if hasattr(group_dns, 'value'):
+
+                if hasattr(group_dns, "value"):
                     group_dns = group_dns.value
                     log.info(f"Extracted .value property: {group_dns}")
-                elif hasattr(group_dns, '__iter__') and not isinstance(group_dns, (str, bytes)):
+                elif hasattr(group_dns, "__iter__") and not isinstance(
+                    group_dns, (str, bytes)
+                ):
                     group_dns = list(group_dns)
                     log.info(f"Converted to list: {group_dns}")
-                elif not isinstance(group_dns, list):
-                    group_dns = [group_dns]
-                
+
                 if isinstance(group_dns, list):
                     group_dns = [str(item) for item in group_dns]
                 else:
                     group_dns = [str(group_dns)]
-                
-                log.info(f"LDAP group_dns after processing - type: {type(group_dns)}, length: {len(group_dns)}")
-                
-                for i, group_dn in enumerate(group_dns):
-                    group_dn_str = str(group_dn)
-                    log.info(f"Processing group DN #{i+1}: {group_dn_str}")
-                    
+
+                log.info(
+                    f"LDAP group_dns after processing - type: {type(group_dns)}, length: {len(group_dns)}"
+                )
+
+                for group_idx, group_dn in enumerate(group_dns):
+                    group_dn = str(group_dn)
+                    log.info(f"Processing group DN #{group_idx + 1}: {group_dn}")
+
                     try:
-                        cn_part = None
-                        dn_parts = group_dn_str.split(',')
-                        log.debug(f"DN parts: {dn_parts}")
-                        
-                        for part in dn_parts:
-                            part = part.strip()
-                            if part.upper().startswith('CN='):
-                                cn_part = part[3:]  
+                        group_cn = None
+
+                        for item in group_dn.split(","):
+                            item = item.strip()
+                            if item.upper().startswith("CN="):
+                                group_cn = item[3:]
                                 break
-                        
-                        if cn_part:
-                            user_groups.append(cn_part)
+
+                        if group_cn:
+                            user_groups.append(group_cn)
+
                         else:
-                            log.warning(f"Could not extract CN from group DN: {group_dn_str}")
+                            log.warning(
+                                f"Could not extract CN from group DN: {group_dn}"
+                            )
                     except Exception as e:
-                        log.warning(f"Failed to extract group name from DN {group_dn_str}: {e}")
-                
-                log.info(f"LDAP groups for user {username}: {user_groups} (total: {len(user_groups)})")
+                        log.warning(
+                            f"Failed to extract group name from DN {group_dn}: {e}"
+                        )
+
+                log.info(
+                    f"LDAP groups for user {username}: {user_groups} (total: {len(user_groups)})"
+                )
             else:
                 log.info(f"No groups found for user {username}")
         elif ENABLE_LDAP_GROUP_MANAGEMENT:
-            log.warning(f"LDAP Group Management enabled but {LDAP_ATTRIBUTE_FOR_GROUPS} attribute not found in user entry")
+            log.warning(
+                f"LDAP Group Management enabled but {LDAP_ATTRIBUTE_FOR_GROUPS} attribute not found in user entry"
+            )
 
         if username == form_data.user.lower():
             connection_user = Connection(
@@ -398,26 +412,19 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
                     user.id, request.app.state.config.USER_PERMISSIONS
                 )
 
-                if ENABLE_LDAP_GROUP_MANAGEMENT and user_groups and request.app.state.config.ENABLE_LDAP_GROUP_CREATION:
-                    from open_webui.models.groups import GroupForm
-                    existing_groups = Groups.get_groups()
-                    existing_group_names = [grp.name for grp in existing_groups]
-                    log.info(f"Existing groups: {existing_group_names}")
-                    
-                    for i, g in enumerate(user_groups):
-                        if not any(grp.name == g for grp in existing_groups):
-                            try:
-                                Groups.insert_new_group(user.id, GroupForm(name=g, description=f"{LDAP_SERVER_LABEL}"))
-                                log.info(f"Successfully created group '{g}'")
-                            except Exception as e:
-                                log.error(f"Failed to create group '{g}': {e}")
-                        else:
-                            log.info(f"Group {g} already exists")
+                if (
+                    user.role != "admin"
+                    and ENABLE_LDAP_GROUP_MANAGEMENT
+                    and user_groups
+                ):
+                    if ENABLE_LDAP_GROUP_CREATION:
+                        Groups.create_groups_by_group_names(user.id, user_groups)
 
-                if ENABLE_LDAP_GROUP_MANAGEMENT and user_groups and user.role != "admin":
                     try:
-                        Groups.sync_user_groups_by_group_names(user.id, user_groups)
-                        log.info(f"Successfully synced groups for user {user.id}: {user_groups}")
+                        Groups.sync_groups_by_group_names(user.id, user_groups)
+                        log.info(
+                            f"Successfully synced groups for user {user.id}: {user_groups}"
+                        )
                     except Exception as e:
                         log.error(f"Failed to sync groups for user {user.id}: {e}")
 
@@ -473,7 +480,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
             group_names = [name.strip() for name in group_names if name.strip()]
 
             if group_names:
-                Groups.sync_user_groups_by_group_names(user.id, group_names)
+                Groups.sync_groups_by_group_names(user.id, group_names)
 
     elif WEBUI_AUTH == False:
         admin_email = "admin@localhost"