Browse Source

refac: has_users

Co-Authored-By: pickle-dice <159401444+hassan-ajek@users.noreply.github.com>
Timothy Jaeryang Baek 2 months ago
parent
commit
f24b76d9a3

+ 4 - 0
backend/open_webui/models/users.py

@@ -258,6 +258,10 @@ class UsersTable:
         with get_db() as db:
         with get_db() as db:
             return db.query(User).count()
             return db.query(User).count()
 
 
+    def has_users(self) -> bool:
+        with get_db() as db:
+            return db.query(db.query(User).exists()).scalar()
+
     def get_first_user(self) -> UserModel:
     def get_first_user(self) -> UserModel:
         try:
         try:
             with get_db() as db:
             with get_db() as db:

+ 6 - 10
backend/open_webui/routers/auths.py

@@ -351,11 +351,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
             user = Users.get_user_by_email(email)
             user = Users.get_user_by_email(email)
             if not user:
             if not user:
                 try:
                 try:
-                    user_count = Users.get_num_users()
-
                     role = (
                     role = (
                         "admin"
                         "admin"
-                        if user_count == 0
+                        if not Users.has_users()
                         else request.app.state.config.DEFAULT_USER_ROLE
                         else request.app.state.config.DEFAULT_USER_ROLE
                     )
                     )
 
 
@@ -489,7 +487,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
         if Users.get_user_by_email(admin_email.lower()):
         if Users.get_user_by_email(admin_email.lower()):
             user = Auths.authenticate_user(admin_email.lower(), admin_password)
             user = Auths.authenticate_user(admin_email.lower(), admin_password)
         else:
         else:
-            if Users.get_num_users() != 0:
+            if Users.has_users():
                 raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
                 raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
 
 
             await signup(
             await signup(
@@ -556,6 +554,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
 
 
 @router.post("/signup", response_model=SessionUserResponse)
 @router.post("/signup", response_model=SessionUserResponse)
 async def signup(request: Request, response: Response, form_data: SignupForm):
 async def signup(request: Request, response: Response, form_data: SignupForm):
+    has_users = Users.has_users()
 
 
     if WEBUI_AUTH:
     if WEBUI_AUTH:
         if (
         if (
@@ -566,12 +565,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
                 status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
                 status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
             )
             )
     else:
     else:
-        if Users.get_num_users() != 0:
+        if has_users:
             raise HTTPException(
             raise HTTPException(
                 status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
                 status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
             )
             )
 
 
-    user_count = Users.get_num_users()
     if not validate_email_format(form_data.email.lower()):
     if not validate_email_format(form_data.email.lower()):
         raise HTTPException(
         raise HTTPException(
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@@ -581,9 +579,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
 
 
     try:
     try:
-        role = (
-            "admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
-        )
+        role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE
 
 
         # The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
         # The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
         if len(form_data.password.encode("utf-8")) > 72:
         if len(form_data.password.encode("utf-8")) > 72:
@@ -644,7 +640,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
                 user.id, request.app.state.config.USER_PERMISSIONS
                 user.id, request.app.state.config.USER_PERMISSIONS
             )
             )
 
 
-            if user_count == 0:
+            if not has_users:
                 # Disable signup after the first user is created
                 # Disable signup after the first user is created
                 request.app.state.config.ENABLE_SIGNUP = False
                 request.app.state.config.ENABLE_SIGNUP = False
 
 

+ 3 - 4
backend/open_webui/utils/oauth.py

@@ -88,11 +88,12 @@ class OAuthManager:
         return self.oauth.create_client(provider_name)
         return self.oauth.create_client(provider_name)
 
 
     def get_user_role(self, user, user_data):
     def get_user_role(self, user, user_data):
-        if user and Users.get_num_users() == 1:
+        user_count = Users.get_num_users()
+        if user and user_count == 1:
             # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
             # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
             log.debug("Assigning the only user the admin role")
             log.debug("Assigning the only user the admin role")
             return "admin"
             return "admin"
-        if not user and Users.get_num_users() == 0:
+        if not user and user_count == 0:
             # If there are no users, assign the role "admin", as the first user will be an admin
             # If there are no users, assign the role "admin", as the first user will be an admin
             log.debug("Assigning the first user the admin role")
             log.debug("Assigning the first user the admin role")
             return "admin"
             return "admin"
@@ -449,8 +450,6 @@ class OAuthManager:
                         log.debug(f"Updated profile picture for user {user.email}")
                         log.debug(f"Updated profile picture for user {user.email}")
 
 
         if not user:
         if not user:
-            user_count = Users.get_num_users()
-
             # If the user does not exist, check if signups are enabled
             # If the user does not exist, check if signups are enabled
             if auth_manager_config.ENABLE_OAUTH_SIGNUP:
             if auth_manager_config.ENABLE_OAUTH_SIGNUP:
                 # Check if an existing user with the same email already exists
                 # Check if an existing user with the same email already exists