فهرست منبع

refac/enh: display oauth error as toast

Timothy Jaeryang Baek 1 ماه پیش
والد
کامیت
3d6d050ad8
2فایلهای تغییر یافته به همراه197 افزوده شده و 166 حذف شده
  1. 191 165
      backend/open_webui/utils/oauth.py
  2. 6 1
      src/routes/auth/+page.svelte

+ 191 - 165
backend/open_webui/utils/oauth.py

@@ -401,185 +401,207 @@ class OAuthManager:
     async def handle_callback(self, request, provider, response):
         if provider not in OAUTH_PROVIDERS:
             raise HTTPException(404)
-        client = self.get_client(provider)
+
+        error_message = None
         try:
-            token = await client.authorize_access_token(request)
-        except Exception as e:
-            log.warning(f"OAuth callback error: {e}")
-            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-        user_data: UserInfo = token.get("userinfo")
-        if (
-            (not user_data)
-            or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data)
-            or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
-        ):
-            user_data: UserInfo = await client.userinfo(token=token)
-        if not user_data:
-            log.warning(f"OAuth callback failed, user data is missing: {token}")
-            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+            client = self.get_client(provider)
+            try:
+                token = await client.authorize_access_token(request)
+            except Exception as e:
+                log.warning(f"OAuth callback error: {e}")
+                raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+            user_data: UserInfo = token.get("userinfo")
+            if (
+                (not user_data)
+                or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data)
+                or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
+            ):
+                user_data: UserInfo = await client.userinfo(token=token)
+            if not user_data:
+                log.warning(f"OAuth callback failed, user data is missing: {token}")
+                raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 
-        if auth_manager_config.OAUTH_SUB_CLAIM:
-            sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
-        else:
-            # Fallback to the default sub claim if not configured
-            sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
-
-        if not sub:
-            log.warning(f"OAuth callback failed, sub is missing: {user_data}")
-            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-
-        provider_sub = f"{provider}@{sub}"
-
-        email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
-        email = user_data.get(email_claim, "")
-        # We currently mandate that email addresses are provided
-        if not email:
-            # If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email
-            if provider == "github":
-                try:
-                    access_token = token.get("access_token")
-                    headers = {"Authorization": f"Bearer {access_token}"}
-                    async with aiohttp.ClientSession(trust_env=True) as session:
-                        async with session.get(
-                            "https://api.github.com/user/emails",
-                            headers=headers,
-                            ssl=AIOHTTP_CLIENT_SESSION_SSL,
-                        ) as resp:
-                            if resp.ok:
-                                emails = await resp.json()
-                                # use the primary email as the user's email
-                                primary_email = next(
-                                    (e["email"] for e in emails if e.get("primary")),
-                                    None,
-                                )
-                                if primary_email:
-                                    email = primary_email
-                                else:
-                                    log.warning(
-                                        "No primary email found in GitHub response"
+            if auth_manager_config.OAUTH_SUB_CLAIM:
+                sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
+            else:
+                # Fallback to the default sub claim if not configured
+                sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
+
+            if not sub:
+                log.warning(f"OAuth callback failed, sub is missing: {user_data}")
+                raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+
+            provider_sub = f"{provider}@{sub}"
+
+            email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
+            email = user_data.get(email_claim, "")
+            # We currently mandate that email addresses are provided
+            if not email:
+                # If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email
+                if provider == "github":
+                    try:
+                        access_token = token.get("access_token")
+                        headers = {"Authorization": f"Bearer {access_token}"}
+                        async with aiohttp.ClientSession(trust_env=True) as session:
+                            async with session.get(
+                                "https://api.github.com/user/emails",
+                                headers=headers,
+                                ssl=AIOHTTP_CLIENT_SESSION_SSL,
+                            ) as resp:
+                                if resp.ok:
+                                    emails = await resp.json()
+                                    # use the primary email as the user's email
+                                    primary_email = next(
+                                        (
+                                            e["email"]
+                                            for e in emails
+                                            if e.get("primary")
+                                        ),
+                                        None,
                                     )
+                                    if primary_email:
+                                        email = primary_email
+                                    else:
+                                        log.warning(
+                                            "No primary email found in GitHub response"
+                                        )
+                                        raise HTTPException(
+                                            400, detail=ERROR_MESSAGES.INVALID_CRED
+                                        )
+                                else:
+                                    log.warning("Failed to fetch GitHub email")
                                     raise HTTPException(
                                         400, detail=ERROR_MESSAGES.INVALID_CRED
                                     )
-                            else:
-                                log.warning("Failed to fetch GitHub email")
-                                raise HTTPException(
-                                    400, detail=ERROR_MESSAGES.INVALID_CRED
-                                )
-                except Exception as e:
-                    log.warning(f"Error fetching GitHub email: {e}")
+                    except Exception as e:
+                        log.warning(f"Error fetching GitHub email: {e}")
+                        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+                else:
+                    log.warning(f"OAuth callback failed, email is missing: {user_data}")
                     raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-            else:
-                log.warning(f"OAuth callback failed, email is missing: {user_data}")
+            email = email.lower()
+            if (
+                "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
+                and email.split("@")[-1]
+                not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
+            ):
+                log.warning(
+                    f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}"
+                )
                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-        email = email.lower()
-        if (
-            "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
-            and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
-        ):
-            log.warning(
-                f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}"
-            )
-            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-
-        # Check if the user exists
-        user = Users.get_user_by_oauth_sub(provider_sub)
-
-        if not user:
-            # If the user does not exist, check if merging is enabled
-            if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
-                # Check if the user exists by email
-                user = Users.get_user_by_email(email)
-                if user:
-                    # Update the user with the new oauth sub
-                    Users.update_user_oauth_sub_by_id(user.id, provider_sub)
-
-        if user:
-            determined_role = self.get_user_role(user, user_data)
-            if user.role != determined_role:
-                Users.update_user_role_by_id(user.id, determined_role)
-
-            # Update profile picture if enabled and different from current
-            if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
-                picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
-                if picture_claim:
-                    new_picture_url = user_data.get(
-                        picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "")
-                    )
-                    processed_picture_url = await self._process_picture_url(
-                        new_picture_url, token.get("access_token")
-                    )
-                    if processed_picture_url != user.profile_image_url:
-                        Users.update_user_profile_image_url_by_id(
-                            user.id, processed_picture_url
+
+            # Check if the user exists
+            user = Users.get_user_by_oauth_sub(provider_sub)
+
+            if not user:
+                # If the user does not exist, check if merging is enabled
+                if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
+                    # Check if the user exists by email
+                    user = Users.get_user_by_email(email)
+                    if user:
+                        # Update the user with the new oauth sub
+                        Users.update_user_oauth_sub_by_id(user.id, provider_sub)
+
+            if user:
+                determined_role = self.get_user_role(user, user_data)
+                if user.role != determined_role:
+                    Users.update_user_role_by_id(user.id, determined_role)
+
+                # Update profile picture if enabled and different from current
+                if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
+                    picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
+                    if picture_claim:
+                        new_picture_url = user_data.get(
+                            picture_claim,
+                            OAUTH_PROVIDERS[provider].get("picture_url", ""),
                         )
-                        log.debug(f"Updated profile picture for user {user.email}")
-
-        if not user:
-            # If the user does not exist, check if signups are enabled
-            if auth_manager_config.ENABLE_OAUTH_SIGNUP:
-                # Check if an existing user with the same email already exists
-                existing_user = Users.get_user_by_email(email)
-                if existing_user:
-                    raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
-
-                picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
-                if picture_claim:
-                    picture_url = user_data.get(
-                        picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "")
-                    )
-                    picture_url = await self._process_picture_url(
-                        picture_url, token.get("access_token")
+                        processed_picture_url = await self._process_picture_url(
+                            new_picture_url, token.get("access_token")
+                        )
+                        if processed_picture_url != user.profile_image_url:
+                            Users.update_user_profile_image_url_by_id(
+                                user.id, processed_picture_url
+                            )
+                            log.debug(f"Updated profile picture for user {user.email}")
+
+            if not user:
+                # If the user does not exist, check if signups are enabled
+                if auth_manager_config.ENABLE_OAUTH_SIGNUP:
+                    # Check if an existing user with the same email already exists
+                    existing_user = Users.get_user_by_email(email)
+                    if existing_user:
+                        raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
+
+                    picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
+                    if picture_claim:
+                        picture_url = user_data.get(
+                            picture_claim,
+                            OAUTH_PROVIDERS[provider].get("picture_url", ""),
+                        )
+                        picture_url = await self._process_picture_url(
+                            picture_url, token.get("access_token")
+                        )
+                    else:
+                        picture_url = "/user.png"
+
+                    username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
+
+                    name = user_data.get(username_claim)
+                    if not name:
+                        log.warning("Username claim is missing, using email as name")
+                        name = email
+
+                    role = self.get_user_role(None, user_data)
+
+                    user = Auths.insert_new_auth(
+                        email=email,
+                        password=get_password_hash(
+                            str(uuid.uuid4())
+                        ),  # Random password, not used
+                        name=name,
+                        profile_image_url=picture_url,
+                        role=role,
+                        oauth_sub=provider_sub,
                     )
-                else:
-                    picture_url = "/user.png"
-
-                username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
-
-                name = user_data.get(username_claim)
-                if not name:
-                    log.warning("Username claim is missing, using email as name")
-                    name = email
-
-                role = self.get_user_role(None, user_data)
-
-                user = Auths.insert_new_auth(
-                    email=email,
-                    password=get_password_hash(
-                        str(uuid.uuid4())
-                    ),  # Random password, not used
-                    name=name,
-                    profile_image_url=picture_url,
-                    role=role,
-                    oauth_sub=provider_sub,
-                )
 
-                if auth_manager_config.WEBHOOK_URL:
-                    await post_webhook(
-                        WEBUI_NAME,
-                        auth_manager_config.WEBHOOK_URL,
-                        WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
-                        {
-                            "action": "signup",
-                            "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
-                            "user": user.model_dump_json(exclude_none=True),
-                        },
+                    if auth_manager_config.WEBHOOK_URL:
+                        await post_webhook(
+                            WEBUI_NAME,
+                            auth_manager_config.WEBHOOK_URL,
+                            WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
+                            {
+                                "action": "signup",
+                                "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
+                                "user": user.model_dump_json(exclude_none=True),
+                            },
+                        )
+                else:
+                    raise HTTPException(
+                        status.HTTP_403_FORBIDDEN,
+                        detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
                     )
-            else:
-                raise HTTPException(
-                    status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
-                )
 
-        jwt_token = create_token(
-            data={"id": user.id},
-            expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
-        )
+            jwt_token = create_token(
+                data={"id": user.id},
+                expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
+            )
 
-        if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT and user.role != "admin":
-            self.update_user_groups(
-                user=user,
-                user_data=user_data,
-                default_permissions=request.app.state.config.USER_PERMISSIONS,
+            if (
+                auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT
+                and user.role != "admin"
+            ):
+                self.update_user_groups(
+                    user=user,
+                    user_data=user_data,
+                    default_permissions=request.app.state.config.USER_PERMISSIONS,
+                )
+
+        except Exception as e:
+            log.error(f"Error during OAuth process: {e}")
+            error_message = (
+                e.detail
+                if isinstance(e, HTTPException) and e.detail
+                else ERROR_MESSAGES.DEFAULT("Error during OAuth process")
             )
 
         redirect_base_url = str(request.app.state.config.WEBUI_URL or request.base_url)
@@ -587,6 +609,10 @@ class OAuthManager:
             redirect_base_url = redirect_base_url[:-1]
         redirect_url = f"{redirect_base_url}/auth"
 
+        if error_message:
+            redirect_url = f"{redirect_url}?error={error_message}"
+            return RedirectResponse(url=redirect_url, headers=response.headers)
+
         response = RedirectResponse(url=redirect_url, headers=response.headers)
 
         # Set the cookie token

+ 6 - 1
src/routes/auth/+page.svelte

@@ -162,8 +162,13 @@
 				localStorage.setItem('redirectPath', redirectPath);
 			}
 		}
-		await oauthCallbackHandler();
 
+		const error = $page.url.searchParams.get('error');
+		if (error) {
+			toast.error(error);
+		}
+
+		await oauthCallbackHandler();
 		form = $page.url.searchParams.get('form');
 
 		loaded = true;