|
@@ -4,6 +4,11 @@ import mimetypes
|
|
|
import sys
|
|
|
import uuid
|
|
|
import json
|
|
|
+from datetime import datetime, timedelta
|
|
|
+
|
|
|
+import re
|
|
|
+import fnmatch
|
|
|
+import time
|
|
|
|
|
|
import aiohttp
|
|
|
from authlib.integrations.starlette_client import OAuth
|
|
@@ -14,8 +19,12 @@ from fastapi import (
|
|
|
)
|
|
|
from starlette.responses import RedirectResponse
|
|
|
|
|
|
+
|
|
|
from open_webui.models.auths import Auths
|
|
|
+from open_webui.models.oauth_sessions import OAuthSessions
|
|
|
from open_webui.models.users import Users
|
|
|
+
|
|
|
+
|
|
|
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
|
|
|
from open_webui.config import (
|
|
|
DEFAULT_USER_ROLE,
|
|
@@ -46,6 +55,7 @@ from open_webui.env import (
|
|
|
WEBUI_NAME,
|
|
|
WEBUI_AUTH_COOKIE_SAME_SITE,
|
|
|
WEBUI_AUTH_COOKIE_SECURE,
|
|
|
+ ENABLE_OAUTH_ID_TOKEN_COOKIE,
|
|
|
)
|
|
|
from open_webui.utils.misc import parse_duration
|
|
|
from open_webui.utils.auth import get_password_hash, create_token
|
|
@@ -79,15 +89,235 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
|
|
auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN = OAUTH_UPDATE_PICTURE_ON_LOGIN
|
|
|
|
|
|
|
|
|
+def is_in_blocked_groups(group_name: str, groups: list) -> bool:
|
|
|
+ """
|
|
|
+ Check if a group name matches any blocked pattern.
|
|
|
+ Supports exact matches, shell-style wildcards (*, ?), and regex patterns.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ group_name: The group name to check
|
|
|
+ groups: List of patterns to match against
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ True if the group is blocked, False otherwise
|
|
|
+ """
|
|
|
+ if not groups:
|
|
|
+ return False
|
|
|
+
|
|
|
+ for group_pattern in groups:
|
|
|
+ if not group_pattern: # Skip empty patterns
|
|
|
+ continue
|
|
|
+
|
|
|
+ # Exact match
|
|
|
+ if group_name == group_pattern:
|
|
|
+ return True
|
|
|
+
|
|
|
+ # Try as regex pattern first if it contains regex-specific characters
|
|
|
+ if any(
|
|
|
+ char in group_pattern
|
|
|
+ for char in ["^", "$", "[", "]", "(", ")", "{", "}", "+", "\\", "|"]
|
|
|
+ ):
|
|
|
+ try:
|
|
|
+ # Use the original pattern as-is for regex matching
|
|
|
+ if re.search(group_pattern, group_name):
|
|
|
+ return True
|
|
|
+ except re.error:
|
|
|
+ # If regex is invalid, fall through to wildcard check
|
|
|
+ pass
|
|
|
+
|
|
|
+ # Shell-style wildcard match (supports * and ?)
|
|
|
+ if "*" in group_pattern or "?" in group_pattern:
|
|
|
+ if fnmatch.fnmatch(group_name, group_pattern):
|
|
|
+ return True
|
|
|
+
|
|
|
+ return False
|
|
|
+
|
|
|
+
|
|
|
class OAuthManager:
|
|
|
def __init__(self, app):
|
|
|
self.oauth = OAuth()
|
|
|
self.app = app
|
|
|
+
|
|
|
+ self._clients = {}
|
|
|
for _, provider_config in OAUTH_PROVIDERS.items():
|
|
|
provider_config["register"](self.oauth)
|
|
|
|
|
|
def get_client(self, provider_name):
|
|
|
- return self.oauth.create_client(provider_name)
|
|
|
+ if provider_name not in self._clients:
|
|
|
+ self._clients[provider_name] = self.oauth.create_client(provider_name)
|
|
|
+ return self._clients[provider_name]
|
|
|
+
|
|
|
+ def get_server_metadata_url(self, provider_name):
|
|
|
+ if provider_name in self._clients:
|
|
|
+ client = self._clients[provider_name]
|
|
|
+ return (
|
|
|
+ client.server_metadata_url
|
|
|
+ if hasattr(client, "server_metadata_url")
|
|
|
+ else None
|
|
|
+ )
|
|
|
+ return None
|
|
|
+
|
|
|
+ def get_oauth_token(
|
|
|
+ self, user_id: str, session_id: str, force_refresh: bool = False
|
|
|
+ ):
|
|
|
+ """
|
|
|
+ Get a valid OAuth token for the user, automatically refreshing if needed.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ user_id: The user ID
|
|
|
+ provider: Optional provider name. If None, gets the most recent session.
|
|
|
+ force_refresh: Force token refresh even if current token appears valid
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: OAuth token data with access_token, or None if no valid token available
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # Get the OAuth session
|
|
|
+ session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
|
|
|
+ if not session:
|
|
|
+ log.warning(
|
|
|
+ f"No OAuth session found for user {user_id}, session {session_id}"
|
|
|
+ )
|
|
|
+ return None
|
|
|
+
|
|
|
+ if force_refresh or datetime.now() + timedelta(
|
|
|
+ minutes=5
|
|
|
+ ) >= datetime.fromtimestamp(session.expires_at):
|
|
|
+ log.debug(
|
|
|
+ f"Token refresh needed for user {user_id}, provider {session.provider}"
|
|
|
+ )
|
|
|
+ refreshed_token = self._refresh_token(session)
|
|
|
+ if refreshed_token:
|
|
|
+ return refreshed_token
|
|
|
+ else:
|
|
|
+ log.warning(
|
|
|
+ f"Token refresh failed for user {user_id}, provider {session.provider}"
|
|
|
+ )
|
|
|
+ return None
|
|
|
+ return session.token
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error getting OAuth token for user {user_id}: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ async def _refresh_token(self, session) -> dict:
|
|
|
+ """
|
|
|
+ Refresh an OAuth token if needed, with concurrency protection.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ session: The OAuth session object
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: Refreshed token data, or None if refresh failed
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # Perform the actual refresh
|
|
|
+ refreshed_token = await self._perform_token_refresh(session)
|
|
|
+
|
|
|
+ if refreshed_token:
|
|
|
+ # Update the session with new token data
|
|
|
+ session = OAuthSessions.update_session_by_id(
|
|
|
+ session.id, refreshed_token
|
|
|
+ )
|
|
|
+ log.info(f"Successfully refreshed token for session {session.id}")
|
|
|
+ return session.token
|
|
|
+ else:
|
|
|
+ log.error(f"Failed to refresh token for session {session.id}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error refreshing token for session {session.id}: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ async def _perform_token_refresh(self, session) -> dict:
|
|
|
+ """
|
|
|
+ Perform the actual OAuth token refresh.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ session: The OAuth session object
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: New token data, or None if refresh failed
|
|
|
+ """
|
|
|
+ provider = session.provider
|
|
|
+ token_data = session.token
|
|
|
+
|
|
|
+ if not token_data.get("refresh_token"):
|
|
|
+ log.warning(f"No refresh token available for session {session.id}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ try:
|
|
|
+ client = self.get_client(provider)
|
|
|
+ if not client:
|
|
|
+ log.error(f"No OAuth client found for provider {provider}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ token_endpoint = None
|
|
|
+ async with aiohttp.ClientSession(trust_env=True) as session_http:
|
|
|
+ async with session_http.get(client.gserver_metadata_url) as r:
|
|
|
+ if r.status == 200:
|
|
|
+ openid_data = await r.json()
|
|
|
+ token_endpoint = openid_data.get("token_endpoint")
|
|
|
+ else:
|
|
|
+ log.error(
|
|
|
+ f"Failed to fetch OpenID configuration for provider {provider}"
|
|
|
+ )
|
|
|
+ if not token_endpoint:
|
|
|
+ log.error(f"No token endpoint found for provider {provider}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ # Prepare refresh request
|
|
|
+ refresh_data = {
|
|
|
+ "grant_type": "refresh_token",
|
|
|
+ "refresh_token": token_data["refresh_token"],
|
|
|
+ "client_id": client.client_id,
|
|
|
+ }
|
|
|
+ # Add client_secret if available (some providers require it)
|
|
|
+ if hasattr(client, "client_secret") and client.client_secret:
|
|
|
+ refresh_data["client_secret"] = client.client_secret
|
|
|
+
|
|
|
+ # Make refresh request
|
|
|
+ async with aiohttp.ClientSession(trust_env=True) as session_http:
|
|
|
+ async with session_http.post(
|
|
|
+ token_endpoint,
|
|
|
+ data=refresh_data,
|
|
|
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
|
+ ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
|
|
+ ) as r:
|
|
|
+ if r.status == 200:
|
|
|
+ new_token_data = await r.json()
|
|
|
+
|
|
|
+ # Merge with existing token data (preserve refresh_token if not provided)
|
|
|
+ if "refresh_token" not in new_token_data:
|
|
|
+ new_token_data["refresh_token"] = token_data[
|
|
|
+ "refresh_token"
|
|
|
+ ]
|
|
|
+
|
|
|
+ # Add timestamp for tracking
|
|
|
+ new_token_data["issued_at"] = datetime.now().timestamp()
|
|
|
+
|
|
|
+ # Calculate expires_at if we have expires_in
|
|
|
+ if (
|
|
|
+ "expires_in" in new_token_data
|
|
|
+ and "expires_at" not in new_token_data
|
|
|
+ ):
|
|
|
+ new_token_data["expires_at"] = (
|
|
|
+ datetime.now().timestamp()
|
|
|
+ + new_token_data["expires_in"]
|
|
|
+ )
|
|
|
+
|
|
|
+ log.debug(f"Token refresh successful for provider {provider}")
|
|
|
+ return new_token_data
|
|
|
+ else:
|
|
|
+ error_text = await r.text()
|
|
|
+ log.error(
|
|
|
+ f"Token refresh failed for provider {provider}: {r.status} - {error_text}"
|
|
|
+ )
|
|
|
+ return None
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Exception during token refresh for provider {provider}: {e}")
|
|
|
+ return None
|
|
|
|
|
|
def get_user_role(self, user, user_data):
|
|
|
user_count = Users.get_num_users()
|
|
@@ -238,7 +468,7 @@ class OAuthManager:
|
|
|
if (
|
|
|
user_oauth_groups
|
|
|
and group_model.name not in user_oauth_groups
|
|
|
- and group_model.name not in blocked_groups
|
|
|
+ and not is_in_blocked_groups(group_model.name, blocked_groups)
|
|
|
):
|
|
|
# Remove group from user
|
|
|
log.debug(
|
|
@@ -269,7 +499,7 @@ class OAuthManager:
|
|
|
user_oauth_groups
|
|
|
and group_model.name in user_oauth_groups
|
|
|
and not any(gm.name == group_model.name for gm in user_current_groups)
|
|
|
- and group_model.name not in blocked_groups
|
|
|
+ and not is_in_blocked_groups(group_model.name, blocked_groups)
|
|
|
):
|
|
|
# Add user to group
|
|
|
log.debug(
|
|
@@ -354,185 +584,205 @@ class OAuthManager:
|
|
|
async def handle_callback(self, request, provider, response):
|
|
|
if provider not in OAUTH_PROVIDERS:
|
|
|
raise HTTPException(404)
|
|
|
- client = self.get_client(provider)
|
|
|
+
|
|
|
+ error_message = None
|
|
|
try:
|
|
|
- token = await client.authorize_access_token(request)
|
|
|
- except Exception as e:
|
|
|
- log.warning(f"OAuth callback error: {e}")
|
|
|
- raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
- user_data: UserInfo = token.get("userinfo")
|
|
|
- if (
|
|
|
- (not user_data)
|
|
|
- or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data)
|
|
|
- or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
|
|
|
- ):
|
|
|
- user_data: UserInfo = await client.userinfo(token=token)
|
|
|
- if not user_data:
|
|
|
- log.warning(f"OAuth callback failed, user data is missing: {token}")
|
|
|
- raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
+ client = self.get_client(provider)
|
|
|
+ try:
|
|
|
+ token = await client.authorize_access_token(request)
|
|
|
+ except Exception as e:
|
|
|
+ log.warning(f"OAuth callback error: {e}")
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
|
|
|
- if auth_manager_config.OAUTH_SUB_CLAIM:
|
|
|
- sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
|
|
|
- else:
|
|
|
- # Fallback to the default sub claim if not configured
|
|
|
- sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
|
|
|
-
|
|
|
- if not sub:
|
|
|
- log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
|
|
- raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
-
|
|
|
- provider_sub = f"{provider}@{sub}"
|
|
|
-
|
|
|
- email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
|
|
- email = user_data.get(email_claim, "")
|
|
|
- # We currently mandate that email addresses are provided
|
|
|
- if not email:
|
|
|
- # If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email
|
|
|
- if provider == "github":
|
|
|
- try:
|
|
|
- access_token = token.get("access_token")
|
|
|
- headers = {"Authorization": f"Bearer {access_token}"}
|
|
|
- async with aiohttp.ClientSession(trust_env=True) as session:
|
|
|
- async with session.get(
|
|
|
- "https://api.github.com/user/emails",
|
|
|
- headers=headers,
|
|
|
- ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
|
|
- ) as resp:
|
|
|
- if resp.ok:
|
|
|
- emails = await resp.json()
|
|
|
- # use the primary email as the user's email
|
|
|
- primary_email = next(
|
|
|
- (e["email"] for e in emails if e.get("primary")),
|
|
|
- None,
|
|
|
- )
|
|
|
- if primary_email:
|
|
|
- email = primary_email
|
|
|
- else:
|
|
|
- log.warning(
|
|
|
- "No primary email found in GitHub response"
|
|
|
+ # Try to get userinfo from the token first, some providers include it there
|
|
|
+ user_data: UserInfo = token.get("userinfo")
|
|
|
+ if (
|
|
|
+ (not user_data)
|
|
|
+ or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data)
|
|
|
+ or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
|
|
|
+ ):
|
|
|
+ user_data: UserInfo = await client.userinfo(token=token)
|
|
|
+ if not user_data:
|
|
|
+ log.warning(f"OAuth callback failed, user data is missing: {token}")
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
+
|
|
|
+ # Extract the "sub" claim, using custom claim if configured
|
|
|
+ if auth_manager_config.OAUTH_SUB_CLAIM:
|
|
|
+ sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
|
|
|
+ else:
|
|
|
+ # Fallback to the default sub claim if not configured
|
|
|
+ sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
|
|
|
+ if not sub:
|
|
|
+ log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
+
|
|
|
+ provider_sub = f"{provider}@{sub}"
|
|
|
+
|
|
|
+ # Email extraction
|
|
|
+ email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
|
|
+ email = user_data.get(email_claim, "")
|
|
|
+ # We currently mandate that email addresses are provided
|
|
|
+ if not email:
|
|
|
+ # If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email
|
|
|
+ if provider == "github":
|
|
|
+ try:
|
|
|
+ access_token = token.get("access_token")
|
|
|
+ headers = {"Authorization": f"Bearer {access_token}"}
|
|
|
+ async with aiohttp.ClientSession(trust_env=True) as session:
|
|
|
+ async with session.get(
|
|
|
+ "https://api.github.com/user/emails",
|
|
|
+ headers=headers,
|
|
|
+ ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
|
|
+ ) as resp:
|
|
|
+ if resp.ok:
|
|
|
+ emails = await resp.json()
|
|
|
+ # use the primary email as the user's email
|
|
|
+ primary_email = next(
|
|
|
+ (
|
|
|
+ e["email"]
|
|
|
+ for e in emails
|
|
|
+ if e.get("primary")
|
|
|
+ ),
|
|
|
+ None,
|
|
|
)
|
|
|
+ if primary_email:
|
|
|
+ email = primary_email
|
|
|
+ else:
|
|
|
+ log.warning(
|
|
|
+ "No primary email found in GitHub response"
|
|
|
+ )
|
|
|
+ raise HTTPException(
|
|
|
+ 400, detail=ERROR_MESSAGES.INVALID_CRED
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ log.warning("Failed to fetch GitHub email")
|
|
|
raise HTTPException(
|
|
|
400, detail=ERROR_MESSAGES.INVALID_CRED
|
|
|
)
|
|
|
- else:
|
|
|
- log.warning("Failed to fetch GitHub email")
|
|
|
- raise HTTPException(
|
|
|
- 400, detail=ERROR_MESSAGES.INVALID_CRED
|
|
|
- )
|
|
|
- except Exception as e:
|
|
|
- log.warning(f"Error fetching GitHub email: {e}")
|
|
|
+ except Exception as e:
|
|
|
+ log.warning(f"Error fetching GitHub email: {e}")
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
+ else:
|
|
|
+ log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
- else:
|
|
|
- log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
|
|
+ email = email.lower()
|
|
|
+
|
|
|
+ # If allowed domains are configured, check if the email domain is in the list
|
|
|
+ if (
|
|
|
+ "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
|
|
+ and email.split("@")[-1]
|
|
|
+ not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
|
|
+ ):
|
|
|
+ log.warning(
|
|
|
+ f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}"
|
|
|
+ )
|
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
- email = email.lower()
|
|
|
- if (
|
|
|
- "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
|
|
- and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
|
|
- ):
|
|
|
- log.warning(
|
|
|
- f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}"
|
|
|
- )
|
|
|
- raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
-
|
|
|
- # Check if the user exists
|
|
|
- user = Users.get_user_by_oauth_sub(provider_sub)
|
|
|
-
|
|
|
- if not user:
|
|
|
- # If the user does not exist, check if merging is enabled
|
|
|
- if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
|
|
- # Check if the user exists by email
|
|
|
- user = Users.get_user_by_email(email)
|
|
|
- if user:
|
|
|
- # Update the user with the new oauth sub
|
|
|
- Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
|
|
-
|
|
|
- if user:
|
|
|
- determined_role = self.get_user_role(user, user_data)
|
|
|
- if user.role != determined_role:
|
|
|
- Users.update_user_role_by_id(user.id, determined_role)
|
|
|
-
|
|
|
- # Update profile picture if enabled and different from current
|
|
|
- if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
|
|
|
- picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
|
|
- if picture_claim:
|
|
|
- new_picture_url = user_data.get(
|
|
|
- picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "")
|
|
|
- )
|
|
|
- processed_picture_url = await self._process_picture_url(
|
|
|
- new_picture_url, token.get("access_token")
|
|
|
- )
|
|
|
- if processed_picture_url != user.profile_image_url:
|
|
|
- Users.update_user_profile_image_url_by_id(
|
|
|
- user.id, processed_picture_url
|
|
|
+
|
|
|
+ # Check if the user exists
|
|
|
+ user = Users.get_user_by_oauth_sub(provider_sub)
|
|
|
+ if not user:
|
|
|
+ # If the user does not exist, check if merging is enabled
|
|
|
+ if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
|
|
+ # Check if the user exists by email
|
|
|
+ user = Users.get_user_by_email(email)
|
|
|
+ if user:
|
|
|
+ # Update the user with the new oauth sub
|
|
|
+ Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
|
|
+
|
|
|
+ if user:
|
|
|
+ determined_role = self.get_user_role(user, user_data)
|
|
|
+ if user.role != determined_role:
|
|
|
+ Users.update_user_role_by_id(user.id, determined_role)
|
|
|
+ # Update profile picture if enabled and different from current
|
|
|
+ if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
|
|
|
+ picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
|
|
+ if picture_claim:
|
|
|
+ new_picture_url = user_data.get(
|
|
|
+ picture_claim,
|
|
|
+ OAUTH_PROVIDERS[provider].get("picture_url", ""),
|
|
|
)
|
|
|
- log.debug(f"Updated profile picture for user {user.email}")
|
|
|
-
|
|
|
- if not user:
|
|
|
- # If the user does not exist, check if signups are enabled
|
|
|
- if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
|
|
- # Check if an existing user with the same email already exists
|
|
|
- existing_user = Users.get_user_by_email(email)
|
|
|
- if existing_user:
|
|
|
- raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
|
|
-
|
|
|
- picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
|
|
- if picture_claim:
|
|
|
- picture_url = user_data.get(
|
|
|
- picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "")
|
|
|
- )
|
|
|
- picture_url = await self._process_picture_url(
|
|
|
- picture_url, token.get("access_token")
|
|
|
+ processed_picture_url = await self._process_picture_url(
|
|
|
+ new_picture_url, token.get("access_token")
|
|
|
+ )
|
|
|
+ if processed_picture_url != user.profile_image_url:
|
|
|
+ Users.update_user_profile_image_url_by_id(
|
|
|
+ user.id, processed_picture_url
|
|
|
+ )
|
|
|
+ log.debug(f"Updated profile picture for user {user.email}")
|
|
|
+ else:
|
|
|
+ # If the user does not exist, check if signups are enabled
|
|
|
+ if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
|
|
+ # Check if an existing user with the same email already exists
|
|
|
+ existing_user = Users.get_user_by_email(email)
|
|
|
+ if existing_user:
|
|
|
+ raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
|
|
+
|
|
|
+ picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
|
|
+ if picture_claim:
|
|
|
+ picture_url = user_data.get(
|
|
|
+ picture_claim,
|
|
|
+ OAUTH_PROVIDERS[provider].get("picture_url", ""),
|
|
|
+ )
|
|
|
+ picture_url = await self._process_picture_url(
|
|
|
+ picture_url, token.get("access_token")
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ picture_url = "/user.png"
|
|
|
+ username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
|
|
|
+
|
|
|
+ name = user_data.get(username_claim)
|
|
|
+ if not name:
|
|
|
+ log.warning("Username claim is missing, using email as name")
|
|
|
+ name = email
|
|
|
+
|
|
|
+ user = Auths.insert_new_auth(
|
|
|
+ email=email,
|
|
|
+ password=get_password_hash(
|
|
|
+ str(uuid.uuid4())
|
|
|
+ ), # Random password, not used
|
|
|
+ name=name,
|
|
|
+ profile_image_url=picture_url,
|
|
|
+ role=self.get_user_role(None, user_data),
|
|
|
+ oauth_sub=provider_sub,
|
|
|
)
|
|
|
- else:
|
|
|
- picture_url = "/user.png"
|
|
|
-
|
|
|
- username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
|
|
|
-
|
|
|
- name = user_data.get(username_claim)
|
|
|
- if not name:
|
|
|
- log.warning("Username claim is missing, using email as name")
|
|
|
- name = email
|
|
|
-
|
|
|
- role = self.get_user_role(None, user_data)
|
|
|
-
|
|
|
- user = Auths.insert_new_auth(
|
|
|
- email=email,
|
|
|
- password=get_password_hash(
|
|
|
- str(uuid.uuid4())
|
|
|
- ), # Random password, not used
|
|
|
- name=name,
|
|
|
- profile_image_url=picture_url,
|
|
|
- role=role,
|
|
|
- oauth_sub=provider_sub,
|
|
|
- )
|
|
|
|
|
|
- if auth_manager_config.WEBHOOK_URL:
|
|
|
- await post_webhook(
|
|
|
- WEBUI_NAME,
|
|
|
- auth_manager_config.WEBHOOK_URL,
|
|
|
- WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
|
- {
|
|
|
- "action": "signup",
|
|
|
- "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
|
- "user": user.model_dump_json(exclude_none=True),
|
|
|
- },
|
|
|
+ if auth_manager_config.WEBHOOK_URL:
|
|
|
+ await post_webhook(
|
|
|
+ WEBUI_NAME,
|
|
|
+ auth_manager_config.WEBHOOK_URL,
|
|
|
+ WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
|
+ {
|
|
|
+ "action": "signup",
|
|
|
+ "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
|
+ "user": user.model_dump_json(exclude_none=True),
|
|
|
+ },
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ raise HTTPException(
|
|
|
+ status.HTTP_403_FORBIDDEN,
|
|
|
+ detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
|
)
|
|
|
- else:
|
|
|
- raise HTTPException(
|
|
|
- status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
|
|
- )
|
|
|
|
|
|
- jwt_token = create_token(
|
|
|
- data={"id": user.id},
|
|
|
- expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
|
|
|
- )
|
|
|
+ jwt_token = create_token(
|
|
|
+ data={"id": user.id},
|
|
|
+ expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
|
|
|
+ )
|
|
|
+ if (
|
|
|
+ auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT
|
|
|
+ and user.role != "admin"
|
|
|
+ ):
|
|
|
+ self.update_user_groups(
|
|
|
+ user=user,
|
|
|
+ user_data=user_data,
|
|
|
+ default_permissions=request.app.state.config.USER_PERMISSIONS,
|
|
|
+ )
|
|
|
|
|
|
- if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT and user.role != "admin":
|
|
|
- self.update_user_groups(
|
|
|
- user=user,
|
|
|
- user_data=user_data,
|
|
|
- default_permissions=request.app.state.config.USER_PERMISSIONS,
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error during OAuth process: {e}")
|
|
|
+ error_message = (
|
|
|
+ e.detail
|
|
|
+ if isinstance(e, HTTPException) and e.detail
|
|
|
+ else ERROR_MESSAGES.DEFAULT("Error during OAuth process")
|
|
|
)
|
|
|
|
|
|
redirect_base_url = str(request.app.state.config.WEBUI_URL or request.base_url)
|
|
@@ -540,6 +790,10 @@ class OAuthManager:
|
|
|
redirect_base_url = redirect_base_url[:-1]
|
|
|
redirect_url = f"{redirect_base_url}/auth"
|
|
|
|
|
|
+ if error_message:
|
|
|
+ redirect_url = f"{redirect_url}?error={error_message}"
|
|
|
+ return RedirectResponse(url=redirect_url, headers=response.headers)
|
|
|
+
|
|
|
response = RedirectResponse(url=redirect_url, headers=response.headers)
|
|
|
|
|
|
# Set the cookie token
|
|
@@ -552,13 +806,48 @@ class OAuthManager:
|
|
|
secure=WEBUI_AUTH_COOKIE_SECURE,
|
|
|
)
|
|
|
|
|
|
- if ENABLE_OAUTH_SIGNUP.value:
|
|
|
- oauth_id_token = token.get("id_token")
|
|
|
+ # Legacy cookies for compatibility with older frontend versions
|
|
|
+ if ENABLE_OAUTH_ID_TOKEN_COOKIE:
|
|
|
response.set_cookie(
|
|
|
key="oauth_id_token",
|
|
|
- value=oauth_id_token,
|
|
|
+ value=token.get("id_token"),
|
|
|
httponly=True,
|
|
|
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
|
|
secure=WEBUI_AUTH_COOKIE_SECURE,
|
|
|
)
|
|
|
+
|
|
|
+ try:
|
|
|
+ # Add timestamp for tracking
|
|
|
+ token["issued_at"] = datetime.now().timestamp()
|
|
|
+
|
|
|
+ # Calculate expires_at if we have expires_in
|
|
|
+ if "expires_in" in token and "expires_at" not in token:
|
|
|
+ token["expires_at"] = datetime.now().timestamp() + token["expires_in"]
|
|
|
+
|
|
|
+ # Clean up any existing sessions for this user/provider first
|
|
|
+ sessions = OAuthSessions.get_sessions_by_user_id(user.id)
|
|
|
+ for session in sessions:
|
|
|
+ if session.provider == provider:
|
|
|
+ OAuthSessions.delete_session_by_id(session.id)
|
|
|
+
|
|
|
+ session = OAuthSessions.create_session(
|
|
|
+ user_id=user.id,
|
|
|
+ provider=provider,
|
|
|
+ token=token,
|
|
|
+ )
|
|
|
+
|
|
|
+ response.set_cookie(
|
|
|
+ key="oauth_session_id",
|
|
|
+ value=session.id,
|
|
|
+ httponly=True,
|
|
|
+ samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
|
|
+ secure=WEBUI_AUTH_COOKIE_SECURE,
|
|
|
+ )
|
|
|
+
|
|
|
+ log.info(
|
|
|
+ f"Stored OAuth session server-side for user {user.id}, provider {provider}"
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Failed to store OAuth session server-side: {e}")
|
|
|
+
|
|
|
return response
|