Browse Source

feat: WEBUI_AUTH_TRUSTED_GROUPS_HEADER

Timothy Jaeryang Baek 1 month ago
parent
commit
cce5f024bd

+ 4 - 0
backend/open_webui/env.py

@@ -349,6 +349,10 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
     "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
 )
 WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
+WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get(
+    "WEBUI_AUTH_TRUSTED_GROUPS_HEADER", None
+)
+
 
 BYPASS_MODEL_ACCESS_CONTROL = (
     os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"

+ 2 - 2
backend/open_webui/models/auths.py

@@ -159,8 +159,8 @@ class AuthsTable:
         except Exception:
             return False
 
-    def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
-        log.info(f"authenticate_user_by_trusted_header: {email}")
+    def authenticate_user_by_email(self, email: str) -> Optional[UserModel]:
+        log.info(f"authenticate_user_by_email: {email}")
         try:
             with get_db() as db:
                 auth = db.query(Auth).filter_by(email=email, active=True).first()

+ 38 - 0
backend/open_webui/models/groups.py

@@ -207,5 +207,43 @@ class GroupTable:
             except Exception:
                 return False
 
+    def sync_user_groups_by_group_names(
+        self, user_id: str, group_names: list[str]
+    ) -> bool:
+        with get_db() as db:
+            try:
+                groups = db.query(Group).filter(Group.name.in_(group_names)).all()
+                group_ids = [group.id for group in groups]
+
+                # Remove user from groups not in the new list
+                existing_groups = self.get_groups_by_member_id(user_id)
+
+                for group in existing_groups:
+                    if group.id not in group_ids:
+                        group.user_ids.remove(user_id)
+                        db.query(Group).filter_by(id=group.id).update(
+                            {
+                                "user_ids": group.user_ids,
+                                "updated_at": int(time.time()),
+                            }
+                        )
+
+                # Add user to new groups
+                for group in groups:
+                    if user_id not in group.user_ids:
+                        group.user_ids.append(user_id)
+                        db.query(Group).filter_by(id=group.id).update(
+                            {
+                                "user_ids": group.user_ids,
+                                "updated_at": int(time.time()),
+                            }
+                        )
+
+                db.commit()
+                return True
+            except Exception as e:
+                log.exception(e)
+                return False
+
 
 Groups = GroupTable()

+ 21 - 11
backend/open_webui/routers/auths.py

@@ -19,12 +19,14 @@ from open_webui.models.auths import (
     UserResponse,
 )
 from open_webui.models.users import Users
+from open_webui.models.groups import Groups
 
 from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 from open_webui.env import (
     WEBUI_AUTH,
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
     WEBUI_AUTH_TRUSTED_NAME_HEADER,
+    WEBUI_AUTH_TRUSTED_GROUPS_HEADER,
     WEBUI_AUTH_COOKIE_SAME_SITE,
     WEBUI_AUTH_COOKIE_SECURE,
     WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
@@ -299,7 +301,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
                         500, detail="Internal error occurred during LDAP user creation."
                     )
 
-            user = Auths.authenticate_user_by_trusted_header(email)
+            user = Auths.authenticate_user_by_email(email)
 
             if user:
                 expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
@@ -363,21 +365,29 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
         if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
 
-        trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
-        trusted_name = trusted_email
+        email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
+        name = email
+
         if WEBUI_AUTH_TRUSTED_NAME_HEADER:
-            trusted_name = request.headers.get(
-                WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
-            )
-        if not Users.get_user_by_email(trusted_email.lower()):
+            name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
+
+        if not Users.get_user_by_email(email.lower()):
             await signup(
                 request,
                 response,
-                SignupForm(
-                    email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
-                ),
+                SignupForm(email=email, password=str(uuid.uuid4()), name=name),
             )
-        user = Auths.authenticate_user_by_trusted_header(trusted_email)
+
+        user = Auths.authenticate_user_by_email(email)
+        if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin":
+            group_names = request.headers.get(
+                WEBUI_AUTH_TRUSTED_GROUPS_HEADER, ""
+            ).split(",")
+            group_names = [name.strip() for name in group_names if name.strip()]
+
+            if group_names:
+                Groups.sync_user_groups_by_group_names(user.id, group_names)
+
     elif WEBUI_AUTH == False:
         admin_email = "admin@localhost"
         admin_password = "admin"