12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344 |
- import base64
- import hashlib
- import logging
- import mimetypes
- import sys
- import urllib
- import uuid
- import json
- from datetime import datetime, timedelta
- import re
- import fnmatch
- import time
- import secrets
- from cryptography.fernet import Fernet
- import aiohttp
- from authlib.integrations.starlette_client import OAuth
- from authlib.oidc.core import UserInfo
- from fastapi import (
- HTTPException,
- status,
- )
- from starlette.responses import RedirectResponse
- from typing import Optional
- from open_webui.models.auths import Auths
- from open_webui.models.oauth_sessions import OAuthSessions
- from open_webui.models.users import Users
- from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
- from open_webui.config import (
- DEFAULT_USER_ROLE,
- ENABLE_OAUTH_SIGNUP,
- OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
- OAUTH_PROVIDERS,
- ENABLE_OAUTH_ROLE_MANAGEMENT,
- ENABLE_OAUTH_GROUP_MANAGEMENT,
- ENABLE_OAUTH_GROUP_CREATION,
- OAUTH_BLOCKED_GROUPS,
- OAUTH_ROLES_CLAIM,
- OAUTH_SUB_CLAIM,
- OAUTH_GROUPS_CLAIM,
- OAUTH_EMAIL_CLAIM,
- OAUTH_PICTURE_CLAIM,
- OAUTH_USERNAME_CLAIM,
- OAUTH_ALLOWED_ROLES,
- OAUTH_ADMIN_ROLES,
- OAUTH_ALLOWED_DOMAINS,
- OAUTH_UPDATE_PICTURE_ON_LOGIN,
- WEBHOOK_URL,
- JWT_EXPIRES_IN,
- AppConfig,
- )
- from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
- from open_webui.env import (
- AIOHTTP_CLIENT_SESSION_SSL,
- WEBUI_NAME,
- WEBUI_AUTH_COOKIE_SAME_SITE,
- WEBUI_AUTH_COOKIE_SECURE,
- ENABLE_OAUTH_ID_TOKEN_COOKIE,
- OAUTH_CLIENT_INFO_ENCRYPTION_KEY,
- )
- from open_webui.utils.misc import parse_duration
- from open_webui.utils.auth import get_password_hash, create_token
- from open_webui.utils.webhook import post_webhook
- from mcp.shared.auth import (
- OAuthClientMetadata,
- OAuthMetadata,
- )
class OAuthClientInformationFull(OAuthClientMetadata):
    # Client metadata extended with the credentials and bookkeeping fields that
    # this module reads back from a dynamic client registration response.

    # URL of the OAuth server that issued this client
    issuer: str | None = None
    client_id: str
    client_secret: str | None = None
    client_id_issued_at: int | None = None
    client_secret_expires_at: int | None = None
- from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
# Send log records to stdout and give this module the OAUTH-specific level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OAUTH"])

# Module-wide configuration object holding the OAuth-related settings that the
# managers in this file read at runtime.
auth_manager_config = AppConfig()

# Sign-up / account handling
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
auth_manager_config.OAUTH_ALLOWED_DOMAINS = OAUTH_ALLOWED_DOMAINS
auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN = OAUTH_UPDATE_PICTURE_ON_LOGIN

# Role management
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES

# Group management
auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT
auth_manager_config.ENABLE_OAUTH_GROUP_CREATION = ENABLE_OAUTH_GROUP_CREATION
auth_manager_config.OAUTH_BLOCKED_GROUPS = OAUTH_BLOCKED_GROUPS
auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM

# Claim names used to read identity fields from the provider's user data
auth_manager_config.OAUTH_SUB_CLAIM = OAUTH_SUB_CLAIM
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM

# Miscellaneous
auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
# Build the module-wide Fernet instance used by encrypt_data / decrypt_data.
FERNET = None

if len(OAUTH_CLIENT_INFO_ENCRYPTION_KEY) == 44:
    # A 44-character value is passed to Fernet as-is (only encoded to bytes).
    OAUTH_CLIENT_INFO_ENCRYPTION_KEY = OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()
else:
    # Any other length is treated as a passphrase: its SHA-256 digest (32 bytes)
    # is url-safe base64 encoded, which yields a 44-character key.
    key_bytes = hashlib.sha256(OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()).digest()
    OAUTH_CLIENT_INFO_ENCRYPTION_KEY = base64.urlsafe_b64encode(key_bytes)

try:
    FERNET = Fernet(OAUTH_CLIENT_INFO_ENCRYPTION_KEY)
except Exception as e:
    # Fail at import time: without a working key nothing can be (de)crypted.
    log.error(f"Error initializing Fernet with provided key: {e}")
    raise
def encrypt_data(data) -> str:
    """Serialize ``data`` to JSON and return it encrypted with the module Fernet key.

    Args:
        data: Any JSON-serializable value.

    Returns:
        The Fernet token as text.

    Raises:
        Exception: Any serialization or encryption error is logged and re-raised.
    """
    try:
        payload = json.dumps(data).encode()
        return FERNET.encrypt(payload).decode()
    except Exception as e:
        log.error(f"Error encrypting data: {e}")
        raise
def decrypt_data(data: str):
    """Decrypt a Fernet token produced by ``encrypt_data`` and parse it as JSON.

    Args:
        data: The encrypted token as text.

    Returns:
        The decoded JSON value.

    Raises:
        Exception: Any decryption or JSON parsing error is logged and re-raised.
    """
    try:
        plaintext = FERNET.decrypt(data.encode()).decode()
        return json.loads(plaintext)
    except Exception as e:
        log.error(f"Error decrypting data: {e}")
        raise
def is_in_blocked_groups(group_name: str, groups: list) -> bool:
    """
    Check if a group name matches any blocked pattern.

    Supports exact matches, shell-style wildcards (*, ?), and regex patterns.
    Matching is case-sensitive on every platform.

    Args:
        group_name: The group name to check
        groups: List of patterns to match against

    Returns:
        True if the group is blocked, False otherwise
    """
    if not groups:
        return False

    # A pattern containing any of these characters is first tried as a regex.
    regex_markers = "^$[](){}+\\|"

    for group_pattern in groups:
        # Skip empty patterns and anything that is not a string (the list is
        # loaded from configuration and may contain stray values).
        if not group_pattern or not isinstance(group_pattern, str):
            continue

        # Exact match
        if group_name == group_pattern:
            return True

        if any(char in group_pattern for char in regex_markers):
            try:
                # Use the original pattern as-is for regex matching
                if re.search(group_pattern, group_name):
                    return True
            except re.error:
                # If regex is invalid, fall through to wildcard check
                pass

        # Shell-style wildcard match (supports * and ?). fnmatchcase is used
        # instead of fnmatch so that group names are not run through
        # os.path.normcase, which would make the result depend on the host OS.
        if "*" in group_pattern or "?" in group_pattern:
            if fnmatch.fnmatchcase(group_name, group_pattern):
                return True

    return False
def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]:
    """Parse ``server_url`` and derive its ``scheme://netloc`` origin.

    Returns:
        A ``(parsed, base_url)`` pair: the ``urlparse`` result and the origin
        string without path, query or fragment.
    """
    parts = urllib.parse.urlparse(server_url)
    origin = "://".join((parts.scheme, parts.netloc))
    return parts, origin
def get_discovery_urls(server_url) -> list[str]:
    """List the well-known metadata URLs to probe for ``server_url``.

    The two root-level documents (OAuth authorization-server metadata, then
    OpenID configuration) always come first. When the server URL has a
    non-root path, the same two documents with that path (trailing slashes
    removed) appended are added after them.
    """
    parsed, base_url = get_parsed_and_base_url(server_url)
    well_known_documents = (
        "/.well-known/oauth-authorization-server",
        "/.well-known/openid-configuration",
    )

    # "" yields the root-level documents; a real path yields the suffixed ones.
    path_suffixes = [""]
    if parsed.path and parsed.path != "/":
        path_suffixes.append(parsed.path.rstrip("/"))

    return [
        urllib.parse.urljoin(base_url, f"{document}{suffix}")
        for suffix in path_suffixes
        for document in well_known_documents
    ]
# TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration.
# This is not currently supported.
async def get_oauth_client_info_with_dynamic_client_registration(
    request,
    client_id: str,
    oauth_server_url: str,
    oauth_server_key: Optional[str] = None,
) -> OAuthClientInformationFull:
    """Register Open WebUI as an OAuth client on ``oauth_server_url``.

    The server's metadata is looked up through the URLs returned by
    ``get_discovery_urls``; the first document that validates supplies the
    registration endpoint and, when the client has no scope yet, the scope
    (all ``scopes_supported`` joined by spaces). If no metadata is found the
    registration request goes to ``<origin>/register``.

    Args:
        request: Incoming request; used to derive the redirect base URL from
            ``app.state.config.WEBUI_URL`` or ``request.base_url``.
        client_id: Local identifier embedded in the redirect URI
            (``/oauth/clients/{client_id}/callback``).
        oauth_server_url: URL of the OAuth server to register with.
        oauth_server_key: Accepted for interface compatibility; not used here.

    Returns:
        The client information returned by the server, with ``issuer`` set to
        the discovery URL that produced valid metadata (``None`` if discovery
        found nothing).

    Raises:
        Exception: If the registration request fails or its response cannot
            be parsed. The error is logged before being re-raised.
    """
    try:
        oauth_server_metadata = None
        oauth_server_metadata_url = None

        redirect_base_url = (
            str(request.app.state.config.WEBUI_URL or request.base_url)
        ).rstrip("/")

        oauth_client_metadata = OAuthClientMetadata(
            client_name="Open WebUI",
            redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            token_endpoint_auth_method="client_secret_post",
        )

        # One HTTP session serves both the discovery probes and the
        # registration request instead of opening a new one per URL.
        async with aiohttp.ClientSession() as session:
            # Attempt to fetch OAuth server metadata to get registration endpoint & scopes
            for url in get_discovery_urls(oauth_server_url):
                try:
                    async with session.get(
                        url, ssl=AIOHTTP_CLIENT_SESSION_SSL
                    ) as oauth_server_metadata_response:
                        if oauth_server_metadata_response.status != 200:
                            continue
                        try:
                            candidate_metadata = OAuthMetadata.model_validate(
                                await oauth_server_metadata_response.json()
                            )
                        except Exception as e:
                            log.error(f"Error parsing OAuth metadata from {url}: {e}")
                            continue
                except aiohttp.ClientError as e:
                    # Discovery is best-effort: an unreachable candidate must
                    # not prevent the remaining URLs from being tried.
                    log.warning(f"Error fetching OAuth metadata from {url}: {e}")
                    continue

                oauth_server_metadata = candidate_metadata
                oauth_server_metadata_url = url
                if (
                    oauth_client_metadata.scope is None
                    and oauth_server_metadata.scopes_supported is not None
                ):
                    oauth_client_metadata.scope = " ".join(
                        oauth_server_metadata.scopes_supported
                    )
                break

            if oauth_server_metadata and oauth_server_metadata.registration_endpoint:
                registration_url = str(oauth_server_metadata.registration_endpoint)
            else:
                # No advertised endpoint: fall back to /register on the origin.
                _, base_url = get_parsed_and_base_url(oauth_server_url)
                registration_url = urllib.parse.urljoin(base_url, "/register")

            registration_data = oauth_client_metadata.model_dump(
                exclude_none=True,
                mode="json",
                by_alias=True,
            )

            # Perform dynamic client registration and return client info
            async with session.post(
                registration_url, json=registration_data, ssl=AIOHTTP_CLIENT_SESSION_SSL
            ) as oauth_client_registration_response:
                try:
                    registration_response_json = (
                        await oauth_client_registration_response.json()
                    )
                    oauth_client_info = OAuthClientInformationFull.model_validate(
                        {
                            **registration_response_json,
                            "issuer": oauth_server_metadata_url,
                        }
                    )
                    log.info(
                        f"Dynamic client registration successful at {registration_url}, client_id: {oauth_client_info.client_id}"
                    )
                    return oauth_client_info
                except Exception as e:
                    error_text = None
                    try:
                        error_text = await oauth_client_registration_response.text()
                        log.error(
                            f"Dynamic client registration failed at {registration_url}: {oauth_client_registration_response.status} - {error_text}"
                        )
                    except Exception as text_error:
                        # Use a distinct name: rebinding `e` here would unbind
                        # the outer exception that is still needed below.
                        log.debug(
                            f"Could not read client registration error body: {text_error}"
                        )

                    log.error(f"Error parsing client registration response: {e}")
                    raise Exception(
                        f"Dynamic client registration failed: {error_text}"
                        if error_text
                        else "Error parsing client registration response"
                    ) from e
    except Exception as e:
        log.error(f"Exception during dynamic client registration: {e}")
        raise
class OAuthClientManager:
    """Registry of dynamically registered OAuth clients and their user tokens.

    ``clients`` maps a local client id to a dict with two keys: ``"client"``
    (the object returned by ``self.oauth.register``) and ``"client_info"``
    (the ``OAuthClientInformationFull`` it was registered from).
    """

    def __init__(self, app):
        self.oauth = OAuth()
        self.app = app
        self.clients = {}

    def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull):
        """Register a client under ``client_id`` and return its registry entry.

        The client's ``issuer`` (if any) is passed on as the
        ``server_metadata_url`` of the registered client.
        """
        self.clients[client_id] = {
            "client": self.oauth.register(
                name=client_id,
                client_id=oauth_client_info.client_id,
                client_secret=oauth_client_info.client_secret,
                client_kwargs=(
                    {"scope": oauth_client_info.scope}
                    if oauth_client_info.scope
                    else {}
                ),
                server_metadata_url=(
                    oauth_client_info.issuer if oauth_client_info.issuer else None
                ),
            ),
            "client_info": oauth_client_info,
        }
        return self.clients[client_id]

    def remove_client(self, client_id):
        """Forget ``client_id``; always returns True, even if it was unknown."""
        if client_id in self.clients:
            del self.clients[client_id]
            log.info(f"Removed OAuth client {client_id}")

        # NOTE(review): authlib's OAuth registry appears to cache registered
        # clients by name in `_clients` / `_registry`, so re-registering the
        # same id would hand back the stale client. Confirm against the
        # installed authlib version; the guarded access makes this a no-op
        # when those attributes do not exist.
        for cache_name in ("_clients", "_registry"):
            cache = getattr(self.oauth, cache_name, None)
            if isinstance(cache, dict):
                cache.pop(client_id, None)
        return True

    def get_client(self, client_id):
        """Return the registered client object for ``client_id``, or None."""
        entry = self.clients.get(client_id)
        return entry["client"] if entry else None

    def get_client_info(self, client_id):
        """Return the stored ``OAuthClientInformationFull``, or None."""
        entry = self.clients.get(client_id)
        return entry["client_info"] if entry else None

    def get_server_metadata_url(self, client_id):
        """Return the metadata URL the client was registered with, or None.

        Registry entries are dicts, so the URL has to be read from the
        ``"client"`` object inside the entry rather than from the entry
        itself. If the client object exposes no URL, the stored ``issuer`` is
        used, since that is the value ``add_client`` registered.
        """
        entry = self.clients.get(client_id)
        if not entry:
            return None

        client = entry["client"]
        # NOTE(review): authlib appears to keep this value in the private
        # `_server_metadata_url` attribute; both spellings are checked.
        for attribute in ("server_metadata_url", "_server_metadata_url"):
            url = getattr(client, attribute, None)
            if url:
                return url

        client_info = entry["client_info"]
        issuer = getattr(client_info, "issuer", None)
        return issuer if issuer else None

    async def get_oauth_token(
        self, user_id: str, client_id: str, force_refresh: bool = False
    ):
        """
        Get a valid OAuth token for the user, automatically refreshing if needed.

        Args:
            user_id: The user ID
            client_id: The OAuth client ID (provider)
            force_refresh: Force token refresh even if current token appears valid

        Returns:
            dict: OAuth token data with access_token, or None if no valid token available
        """
        try:
            # Get the OAuth session
            session = OAuthSessions.get_session_by_provider_and_user_id(
                client_id, user_id
            )
            if not session:
                log.warning(
                    f"No OAuth session found for user {user_id}, client_id {client_id}"
                )
                return None

            # Refresh when forced or when the token expires within 5 minutes.
            if force_refresh or datetime.now() + timedelta(
                minutes=5
            ) >= datetime.fromtimestamp(session.expires_at):
                log.debug(
                    f"Token refresh needed for user {user_id}, client_id {session.provider}"
                )
                refreshed_token = await self._refresh_token(session)
                if refreshed_token:
                    return refreshed_token

                # A session that cannot be refreshed is of no further use.
                log.warning(
                    f"Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}"
                )
                OAuthSessions.delete_session_by_id(session.id)
                return None

            return session.token
        except Exception as e:
            log.error(f"Error getting OAuth token for user {user_id}: {e}")
            return None

    async def _refresh_token(self, session) -> Optional[dict]:
        """
        Refresh the session's token and persist the result.

        Args:
            session: The OAuth session object

        Returns:
            dict: Refreshed token data, or None if refresh failed
        """
        # Keep the id separately: the session variable must stay usable for
        # logging even when the update call returns nothing.
        session_id = session.id
        try:
            refreshed_token = await self._perform_token_refresh(session)
            if not refreshed_token:
                log.error(f"Failed to refresh token for session {session_id}")
                return None

            # Update the session with new token data
            updated_session = OAuthSessions.update_session_by_id(
                session_id, refreshed_token
            )
            if not updated_session:
                log.error(
                    f"Failed to persist refreshed token for session {session_id}"
                )
                return None

            log.info(f"Successfully refreshed token for session {session_id}")
            return updated_session.token
        except Exception as e:
            log.error(f"Error refreshing token for session {session_id}: {e}")
            return None

    async def _perform_token_refresh(self, session) -> Optional[dict]:
        """
        Perform the actual OAuth token refresh.

        The token endpoint is read from the client's server metadata document,
        then a ``refresh_token`` grant is posted to it.

        Args:
            session: The OAuth session object

        Returns:
            dict: New token data, or None if refresh failed
        """
        client_id = session.provider
        token_data = session.token

        if not token_data.get("refresh_token"):
            log.warning(f"No refresh token available for session {session.id}")
            return None

        try:
            client = self.get_client(client_id)
            if not client:
                log.error(f"No OAuth client found for provider {client_id}")
                return None

            server_metadata_url = self.get_server_metadata_url(client_id)
            if not server_metadata_url:
                log.error(f"No server metadata URL found for client_id {client_id}")
                return None

            async with aiohttp.ClientSession(trust_env=True) as session_http:
                token_endpoint = None
                # Same SSL setting as the token request below.
                async with session_http.get(
                    server_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL
                ) as r:
                    if r.status == 200:
                        openid_data = await r.json()
                        token_endpoint = openid_data.get("token_endpoint")
                    else:
                        log.error(
                            f"Failed to fetch OpenID configuration for client_id {client_id}"
                        )

                if not token_endpoint:
                    log.error(f"No token endpoint found for client_id {client_id}")
                    return None

                # Prepare refresh request
                refresh_data = {
                    "grant_type": "refresh_token",
                    "refresh_token": token_data["refresh_token"],
                    "client_id": client.client_id,
                }
                if getattr(client, "client_secret", None):
                    refresh_data["client_secret"] = client.client_secret

                # Make refresh request
                async with session_http.post(
                    token_endpoint,
                    data=refresh_data,
                    headers={"Content-Type": "application/x-www-form-urlencoded"},
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    if r.status != 200:
                        error_text = await r.text()
                        log.error(
                            f"Token refresh failed for client_id {client_id}: {r.status} - {error_text}"
                        )
                        return None

                    new_token_data = await r.json()

                    # Merge with existing token data (preserve refresh_token if not provided)
                    if "refresh_token" not in new_token_data:
                        new_token_data["refresh_token"] = token_data["refresh_token"]

                    # Add timestamp for tracking
                    now = datetime.now().timestamp()
                    new_token_data["issued_at"] = now

                    # Calculate expires_at if we have expires_in
                    if (
                        "expires_in" in new_token_data
                        and "expires_at" not in new_token_data
                    ):
                        new_token_data["expires_at"] = int(
                            now + new_token_data["expires_in"]
                        )

                    log.debug(f"Token refresh successful for client_id {client_id}")
                    return new_token_data
        except Exception as e:
            log.error(f"Exception during token refresh for client_id {client_id}: {e}")
            return None

    async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
        """Start the authorization flow for ``client_id``.

        Raises:
            HTTPException: 404 when the client or its info is not registered.
        """
        client = self.get_client(client_id)
        if client is None:
            raise HTTPException(404)
        client_info = self.get_client_info(client_id)
        if client_info is None:
            raise HTTPException(404)

        # Only stringify a real URI; str(None) would send the literal "None".
        redirect_uri = (
            str(client_info.redirect_uris[0]) if client_info.redirect_uris else None
        )
        return await client.authorize_redirect(request, redirect_uri)

    async def handle_callback(self, request, client_id: str, user_id: str, response):
        """Finish the authorization flow and store the token server-side.

        Any existing session of this user for ``client_id`` is replaced. The
        caller is always redirected to the WebUI base URL; on failure an
        ``error`` query parameter carries a short message.

        Raises:
            HTTPException: 404 when the client is not registered.
        """
        client = self.get_client(client_id)
        if client is None:
            raise HTTPException(404)

        error_message = None
        try:
            token = await client.authorize_access_token(request)
            if token:
                try:
                    # Add timestamp for tracking
                    now = datetime.now().timestamp()
                    token["issued_at"] = now

                    # Calculate expires_at if we have expires_in; stored as an
                    # int, matching what the refresh path writes.
                    if "expires_in" in token and "expires_at" not in token:
                        token["expires_at"] = int(now + token["expires_in"])

                    # Clean up any existing sessions for this user/client_id first
                    for existing in OAuthSessions.get_sessions_by_user_id(user_id):
                        if existing.provider == client_id:
                            OAuthSessions.delete_session_by_id(existing.id)

                    OAuthSessions.create_session(
                        user_id=user_id,
                        provider=client_id,
                        token=token,
                    )
                    log.info(
                        f"Stored OAuth session server-side for user {user_id}, client_id {client_id}"
                    )
                except Exception as e:
                    error_message = "Failed to store OAuth session server-side"
                    log.error(f"Failed to store OAuth session server-side: {e}")
            else:
                error_message = "Failed to obtain OAuth token"
                log.warning(error_message)
        except Exception as e:
            error_message = "OAuth callback error"
            log.warning(f"OAuth callback error: {e}")

        redirect_url = (
            str(request.app.state.config.WEBUI_URL or request.base_url)
        ).rstrip("/")

        if error_message:
            log.debug(error_message)
            redirect_url = f"{redirect_url}/?error={error_message}"

        return RedirectResponse(url=redirect_url, headers=response.headers)
- class OAuthManager:
def __init__(self, app):
    """Create the OAuth registry and register every configured provider.

    Each entry of ``OAUTH_PROVIDERS`` must supply a ``"register"`` callable;
    it is invoked with the registry and its result is cached by provider
    name. Entries without one are logged and skipped.
    """
    self.oauth = OAuth()
    self.app = app
    self._clients = {}

    for name, provider_config in OAUTH_PROVIDERS.items():
        if "register" in provider_config:
            self._clients[name] = provider_config["register"](self.oauth)
        else:
            log.error(f"OAuth provider {name} missing register function")
def get_client(self, provider_name):
    """Return the client for ``provider_name``, creating and caching it on first use."""
    try:
        return self._clients[provider_name]
    except KeyError:
        created = self.oauth.create_client(provider_name)
        self._clients[provider_name] = created
        return created
def get_server_metadata_url(self, provider_name):
    """Return the server metadata URL of a cached provider client, or None.

    Only clients already present in ``self._clients`` are considered. The URL
    is read from ``server_metadata_url`` and, failing that, from
    ``_server_metadata_url``.
    """
    client = self._clients.get(provider_name)
    if client is None:
        return None

    # NOTE(review): authlib appears to store the registered URL in the
    # private `_server_metadata_url` attribute, so checking only the public
    # name would always yield None. Confirm against the installed version.
    for attribute in ("server_metadata_url", "_server_metadata_url"):
        url = getattr(client, attribute, None)
        if url:
            return url
    return None
async def get_oauth_token(
    self, user_id: str, session_id: str, force_refresh: bool = False
):
    """
    Get a valid OAuth token for the user, automatically refreshing if needed.

    Args:
        user_id: The user ID
        session_id: ID of the stored OAuth session to read the token from
        force_refresh: Force token refresh even if current token appears valid

    Returns:
        dict: OAuth token data with access_token, or None if no valid token available
    """
    try:
        session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
        if not session:
            log.warning(
                f"No OAuth session found for user {user_id}, session {session_id}"
            )
            return None

        # A token is refreshed when forced, or when it expires within the
        # next five minutes.
        needs_refresh = force_refresh or (
            datetime.now() + timedelta(minutes=5)
            >= datetime.fromtimestamp(session.expires_at)
        )
        if not needs_refresh:
            return session.token

        log.debug(
            f"Token refresh needed for user {user_id}, provider {session.provider}"
        )
        refreshed_token = await self._refresh_token(session)
        if refreshed_token:
            return refreshed_token

        # The session cannot be refreshed any more, so it is removed.
        log.warning(
            f"Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}"
        )
        OAuthSessions.delete_session_by_id(session.id)
        return None
    except Exception as e:
        log.error(f"Error getting OAuth token for user {user_id}: {e}")
        return None
async def _refresh_token(self, session) -> Optional[dict]:
    """
    Refresh the session's token and persist the result.

    Args:
        session: The OAuth session object

    Returns:
        dict: Refreshed token data, or None if refresh failed
    """
    # Keep the id separately: the error paths below must be able to log it
    # even when the update call returns nothing.
    session_id = session.id
    try:
        refreshed_token = await self._perform_token_refresh(session)
        if not refreshed_token:
            log.error(f"Failed to refresh token for session {session_id}")
            return None

        # Update the session with new token data
        updated_session = OAuthSessions.update_session_by_id(
            session_id, refreshed_token
        )
        if not updated_session:
            log.error(f"Failed to persist refreshed token for session {session_id}")
            return None

        log.info(f"Successfully refreshed token for session {session_id}")
        return updated_session.token
    except Exception as e:
        log.error(f"Error refreshing token for session {session_id}: {e}")
        return None
async def _perform_token_refresh(self, session) -> Optional[dict]:
    """
    Perform the actual OAuth token refresh.

    The token endpoint is read from the provider's server metadata document,
    then a ``refresh_token`` grant is posted to it.

    Args:
        session: The OAuth session object

    Returns:
        dict: New token data, or None if refresh failed
    """
    provider = session.provider
    token_data = session.token

    if not token_data.get("refresh_token"):
        log.warning(f"No refresh token available for session {session.id}")
        return None

    try:
        client = self.get_client(provider)
        if not client:
            log.error(f"No OAuth client found for provider {provider}")
            return None

        # Without a metadata URL there is nothing to fetch; bail out with a
        # clear message instead of failing inside the HTTP call.
        server_metadata_url = self.get_server_metadata_url(provider)
        if not server_metadata_url:
            log.error(f"No server metadata URL found for provider {provider}")
            return None

        async with aiohttp.ClientSession(trust_env=True) as session_http:
            token_endpoint = None
            # Same SSL setting as the token request below.
            async with session_http.get(
                server_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL
            ) as r:
                if r.status == 200:
                    openid_data = await r.json()
                    token_endpoint = openid_data.get("token_endpoint")
                else:
                    log.error(
                        f"Failed to fetch OpenID configuration for provider {provider}"
                    )

            if not token_endpoint:
                log.error(f"No token endpoint found for provider {provider}")
                return None

            # Prepare refresh request
            refresh_data = {
                "grant_type": "refresh_token",
                "refresh_token": token_data["refresh_token"],
                "client_id": client.client_id,
            }
            # Add client_secret if available (some providers require it)
            if getattr(client, "client_secret", None):
                refresh_data["client_secret"] = client.client_secret

            # Make refresh request
            async with session_http.post(
                token_endpoint,
                data=refresh_data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                if r.status != 200:
                    error_text = await r.text()
                    log.error(
                        f"Token refresh failed for provider {provider}: {r.status} - {error_text}"
                    )
                    return None

                new_token_data = await r.json()

                # Merge with existing token data (preserve refresh_token if not provided)
                if "refresh_token" not in new_token_data:
                    new_token_data["refresh_token"] = token_data["refresh_token"]

                # Add timestamp for tracking
                now = datetime.now().timestamp()
                new_token_data["issued_at"] = now

                # Calculate expires_at if we have expires_in
                if (
                    "expires_in" in new_token_data
                    and "expires_at" not in new_token_data
                ):
                    new_token_data["expires_at"] = int(
                        now + new_token_data["expires_in"]
                    )

                log.debug(f"Token refresh successful for provider {provider}")
                return new_token_data
    except Exception as e:
        log.error(f"Exception during token refresh for provider {provider}: {e}")
        return None
def get_user_role(self, user, user_data):
    """Decide which role a user gets on OAuth login.

    Rules, in order:
      1. The only existing user, or the very first user to sign up, is "admin".
      2. With role management disabled, existing users keep their role and
         new users get the configured default role.
      3. With role management enabled, the roles claim (dot-separated path
         into ``user_data``) is read. A match with the allowed roles gives
         "user", a match with the admin roles gives "admin" (admin wins);
         otherwise the default role applies.

    Args:
        user: The existing user object, or a falsy value for a new user.
        user_data: Mapping of claims received from the provider.

    Returns:
        The role name as a string.
    """
    user_count = Users.get_num_users()
    if user and user_count == 1:
        # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
        log.debug("Assigning the only user the admin role")
        return "admin"
    if not user and user_count == 0:
        # If there are no users, assign the role "admin", as the first user will be an admin
        log.debug("Assigning the first user the admin role")
        return "admin"

    if not auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
        # Role management disabled: keep the existing role, or use the
        # default role for new users.
        return user.role if user else auth_manager_config.DEFAULT_USER_ROLE

    log.debug("Running OAUTH Role management")
    oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
    oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
    oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
    oauth_roles = []
    # Default/fallback role if no matching roles are found
    role = auth_manager_config.DEFAULT_USER_ROLE

    # Next block extracts the roles from the user data, accepting nested claims of any depth
    if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
        claim_data = user_data
        for nested_claim in oauth_claim.split("."):
            # An intermediate value that is not a mapping (string, list,
            # None, ...) means the claim path does not exist; treat it as
            # "no roles" instead of failing the login.
            if not callable(getattr(claim_data, "get", None)):
                claim_data = {}
                break
            claim_data = claim_data.get(nested_claim, {})

        if isinstance(claim_data, list):
            oauth_roles = claim_data
        elif isinstance(claim_data, (str, int)):
            oauth_roles = [str(claim_data)]

    log.debug(f"Oauth Roles claim: {oauth_claim}")
    log.debug(f"User roles from oauth: {oauth_roles}")
    log.debug(f"Accepted user roles: {oauth_allowed_roles}")
    log.debug(f"Accepted admin roles: {oauth_admin_roles}")

    # If any roles are found, check if they match the allowed or admin roles
    if oauth_roles:
        # If the user has any of the allowed roles, assign the role "user"
        if any(allowed_role in oauth_roles for allowed_role in oauth_allowed_roles):
            log.debug("Assigned user the user role")
            role = "user"
        # If the user has any of the admin roles, assign the role "admin"
        if any(admin_role in oauth_roles for admin_role in oauth_admin_roles):
            log.debug("Assigned user the admin role")
            role = "admin"

    return role
def update_user_groups(self, user, user_data, default_permissions):
    """Synchronise a user's group memberships with their OAuth groups claim.

    The groups claim is located in ``user_data`` by following the
    dot-separated path in ``auth_manager_config.OAUTH_GROUPS_CLAIM``. When
    ``ENABLE_OAUTH_GROUP_CREATION`` is set, claimed groups that do not exist
    yet are created first. The user is then removed from groups that are not
    in the claim and added to claimed groups they are not yet a member of.
    Groups matched by ``is_in_blocked_groups`` are never modified.

    Args:
        user: The user being synchronised; only ``user.id`` is read.
        user_data: Claims returned by the OAuth provider.
        default_permissions: Permissions given to auto-created groups and
            used in place of empty permissions on existing groups.
    """
    log.debug("Running OAUTH Group management")
    oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM

    try:
        blocked_groups = json.loads(auth_manager_config.OAUTH_BLOCKED_GROUPS)
    except Exception as e:
        log.exception(f"Error loading OAUTH_BLOCKED_GROUPS: {e}")
        blocked_groups = []

    user_oauth_groups = []
    # Nested claim search for groups claim
    if oauth_claim:
        claim_data = user_data
        for nested_claim in oauth_claim.split("."):
            # A list, string or None part-way down the path has no ``get``;
            # treat that as "claim not present" instead of raising.
            if not hasattr(claim_data, "get"):
                claim_data = {}
                break
            claim_data = claim_data.get(nested_claim, {})

        if isinstance(claim_data, list):
            user_oauth_groups = claim_data
        elif isinstance(claim_data, str):
            user_oauth_groups = [claim_data]

    user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
    all_available_groups: list[GroupModel] = Groups.get_groups()

    # Create groups if they don't exist and creation is enabled
    if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
        log.debug("Checking for missing groups to create...")
        all_group_names = {g.name for g in all_available_groups}
        groups_created = False
        # Determine creator ID: Prefer admin, fallback to current user if no admin exists
        admin_user = Users.get_super_admin_user()
        creator_id = admin_user.id if admin_user else user.id
        log.debug(f"Using creator ID {creator_id} for potential group creation.")

        for group_name in user_oauth_groups:
            if group_name in all_group_names:
                continue
            log.info(
                f"Group '{group_name}' not found via OAuth claim. Creating group..."
            )
            try:
                new_group_form = GroupForm(
                    name=group_name,
                    description=f"Group '{group_name}' created automatically via OAuth.",
                    permissions=default_permissions,
                    # Start empty; the user is added by the membership
                    # logic further down.
                    user_ids=[],
                )
                created_group = Groups.insert_new_group(creator_id, new_group_form)
                if created_group:
                    log.info(
                        f"Successfully created group '{group_name}' with ID {created_group.id} using creator ID {creator_id}"
                    )
                    groups_created = True
                    # Prevent duplicate creation attempts in this run
                    all_group_names.add(group_name)
                else:
                    log.error(f"Failed to create group '{group_name}' via OAuth.")
            except Exception as e:
                log.error(f"Error creating group '{group_name}' via OAuth: {e}")

        # Refresh the list of all available groups if any were created
        if groups_created:
            all_available_groups = Groups.get_groups()
            log.debug("Refreshed list of all available groups after creation.")

    log.debug(f"Oauth Groups claim: {oauth_claim}")
    log.debug(f"User oauth groups: {user_oauth_groups}")
    log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
    log.debug(
        f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
    )

    # Both loops below only act when the claim yielded at least one group,
    # so an empty or missing claim leaves existing memberships untouched.
    if not user_oauth_groups:
        return

    def save_members(group_model, member_ids):
        """Persist ``member_ids`` as the member list of ``group_model``."""
        update_form = GroupUpdateForm(
            name=group_model.name,
            description=group_model.description,
            # In case a group is created, but perms are never assigned to
            # the group by hitting "save"
            permissions=group_model.permissions or default_permissions,
            user_ids=member_ids,
        )
        Groups.update_group_by_id(
            id=group_model.id, form_data=update_form, overwrite=False
        )

    current_group_names = {g.name for g in user_current_groups}

    # Remove groups that user is no longer a part of
    for group_model in user_current_groups:
        if group_model.name in user_oauth_groups:
            continue
        if is_in_blocked_groups(group_model.name, blocked_groups):
            continue
        log.debug(
            f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
        )
        remaining_ids = [i for i in (group_model.user_ids or []) if i != user.id]
        save_members(group_model, remaining_ids)

    # Add user to new groups
    for group_model in all_available_groups:
        if group_model.name not in user_oauth_groups:
            continue
        if group_model.name in current_group_names:
            continue
        if is_in_blocked_groups(group_model.name, blocked_groups):
            continue
        log.debug(
            f"Adding user to group {group_model.name} as it was found in their oauth groups"
        )
        # Build a new list rather than appending to the model's own list,
        # so the fetched model is not mutated as a side effect.
        extended_ids = [*(group_model.user_ids or []), user.id]
        save_members(group_model, extended_ids)
- async def _process_picture_url(
- self, picture_url: str, access_token: str = None
- ) -> str:
- """Process a picture URL and return a base64 encoded data URL.
- Args:
- picture_url: The URL of the picture to process
- access_token: Optional OAuth access token for authenticated requests
- Returns:
- A data URL containing the base64 encoded picture, or "/user.png" if processing fails
- """
- if not picture_url:
- return "/user.png"
- try:
- get_kwargs = {}
- if access_token:
- get_kwargs["headers"] = {
- "Authorization": f"Bearer {access_token}",
- }
- async with aiohttp.ClientSession(trust_env=True) as session:
- async with session.get(
- picture_url, **get_kwargs, ssl=AIOHTTP_CLIENT_SESSION_SSL
- ) as resp:
- if resp.ok:
- picture = await resp.read()
- base64_encoded_picture = base64.b64encode(picture).decode(
- "utf-8"
- )
- guessed_mime_type = mimetypes.guess_type(picture_url)[0]
- if guessed_mime_type is None:
- guessed_mime_type = "image/jpeg"
- return (
- f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
- )
- else:
- log.warning(
- f"Failed to fetch profile picture from {picture_url}"
- )
- return "/user.png"
- except Exception as e:
- log.error(f"Error processing profile picture '{picture_url}': {e}")
- return "/user.png"
async def handle_login(self, request, provider):
    """Start the OAuth login flow by redirecting to ``provider``.

    Args:
        request: The incoming request; used to build the callback URL and
            passed on to the OAuth client.
        provider: Key of the provider in ``OAUTH_PROVIDERS``.

    Returns:
        Whatever the OAuth client's ``authorize_redirect`` returns.

    Raises:
        HTTPException: 404 when the provider is not configured or no client
            is available for it.
    """
    if provider not in OAUTH_PROVIDERS:
        raise HTTPException(404)

    # A provider-specific redirect URL wins; otherwise derive the callback
    # URL from the "oauth_login_callback" route.
    custom_redirect = OAUTH_PROVIDERS[provider].get("redirect_uri")
    if custom_redirect:
        redirect_uri = custom_redirect
    else:
        redirect_uri = request.url_for("oauth_login_callback", provider=provider)

    oauth_client = self.get_client(provider)
    if oauth_client is None:
        raise HTTPException(404)

    return await oauth_client.authorize_redirect(request, redirect_uri)
async def handle_callback(self, request, provider, response):
    """Finish the OAuth flow for ``provider`` and redirect to the frontend.

    Exchanges the callback for a token, resolves the user's claims, finds
    (or, when signups are enabled, creates) the matching local user, issues
    a JWT, optionally synchronises groups, and stores the OAuth token in a
    server-side session.

    Args:
        request: The incoming callback request.
        provider: Key of the provider in ``OAUTH_PROVIDERS``.
        response: Response object whose headers are copied onto the
            redirect that is returned.

    Returns:
        A ``RedirectResponse`` to ``<base>/auth``. On failure it carries an
        ``error`` query parameter and no cookies; on success it carries the
        ``token`` cookie and, where applicable, the ``oauth_id_token`` and
        ``oauth_session_id`` cookies.

    Raises:
        HTTPException: 404 when the provider is not configured. All other
            failures are converted into the error redirect.
    """
    # Function-scope import so the block does not depend on file-level imports.
    from urllib.parse import quote

    if provider not in OAUTH_PROVIDERS:
        raise HTTPException(404)

    error_message = None
    try:
        client = self.get_client(provider)
        try:
            token = await client.authorize_access_token(request)
        except Exception as e:
            log.warning(f"OAuth callback error: {e}")
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        # Try to get userinfo from the token first, some providers include it there
        user_data: UserInfo = token.get("userinfo")
        if (
            (not user_data)
            or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data)
            or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
        ):
            user_data = await client.userinfo(token=token)

        # Feishu wraps the claims in a "data" envelope.
        if (
            provider == "feishu"
            and isinstance(user_data, dict)
            and "data" in user_data
        ):
            user_data = user_data["data"]

        if not user_data:
            # Deliberately do not log the token itself: it holds bearer
            # credentials (access/refresh/id tokens).
            log.warning(
                f"OAuth callback failed, user data is missing for provider {provider}"
            )
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        # Extract the "sub" claim, using custom claim if configured
        if auth_manager_config.OAUTH_SUB_CLAIM:
            sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
        else:
            # Fallback to the default sub claim if not configured
            sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
        if not sub:
            log.warning(f"OAuth callback failed, sub is missing: {user_data}")
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        provider_sub = f"{provider}@{sub}"

        # Email extraction
        email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
        email = user_data.get(email_claim, "")
        # We currently mandate that email addresses are provided
        if not email:
            # If the provider is GitHub and no public email is provided, use
            # the access token to fetch the user's primary email.
            if provider == "github":
                try:
                    access_token = token.get("access_token")
                    headers = {"Authorization": f"Bearer {access_token}"}
                    async with aiohttp.ClientSession(trust_env=True) as session:
                        async with session.get(
                            "https://api.github.com/user/emails",
                            headers=headers,
                            ssl=AIOHTTP_CLIENT_SESSION_SSL,
                        ) as resp:
                            if not resp.ok:
                                log.warning("Failed to fetch GitHub email")
                                raise HTTPException(
                                    400, detail=ERROR_MESSAGES.INVALID_CRED
                                )
                            emails = await resp.json()
                            # use the primary email as the user's email
                            primary_email = next(
                                (
                                    entry["email"]
                                    for entry in emails
                                    if entry.get("primary")
                                ),
                                None,
                            )
                            if not primary_email:
                                log.warning(
                                    "No primary email found in GitHub response"
                                )
                                raise HTTPException(
                                    400, detail=ERROR_MESSAGES.INVALID_CRED
                                )
                            email = primary_email
                except Exception as e:
                    log.warning(f"Error fetching GitHub email: {e}")
                    raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
            else:
                log.warning(f"OAuth callback failed, email is missing: {user_data}")
                raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
        email = email.lower()

        # If allowed domains are configured, check if the email domain is in the list
        if (
            "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
            and email.split("@")[-1]
            not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
        ):
            log.warning(
                f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}"
            )
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        # Check if the user exists
        user = Users.get_user_by_oauth_sub(provider_sub)
        if not user and auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
            # Merging is enabled: fall back to a lookup by email and link
            # the existing account to this OAuth identity.
            user = Users.get_user_by_email(email)
            if user:
                Users.update_user_oauth_sub_by_id(user.id, provider_sub)

        if user:
            determined_role = self.get_user_role(user, user_data)
            if user.role != determined_role:
                # NOTE(review): ``user`` is not reloaded after this update,
                # so the ``user.role`` check before group management below
                # presumably still sees the previous role - confirm.
                Users.update_user_role_by_id(user.id, determined_role)

            # Update profile picture if enabled and different from current
            if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
                picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
                if picture_claim:
                    new_picture_url = user_data.get(
                        picture_claim,
                        OAUTH_PROVIDERS[provider].get("picture_url", ""),
                    )
                    processed_picture_url = await self._process_picture_url(
                        new_picture_url, token.get("access_token")
                    )
                    if processed_picture_url != user.profile_image_url:
                        Users.update_user_profile_image_url_by_id(
                            user.id, processed_picture_url
                        )
                        log.debug(f"Updated profile picture for user {user.email}")
        else:
            # If the user does not exist, check if signups are enabled
            if not auth_manager_config.ENABLE_OAUTH_SIGNUP:
                raise HTTPException(
                    status.HTTP_403_FORBIDDEN,
                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
                )

            # Check if an existing user with the same email already exists
            existing_user = Users.get_user_by_email(email)
            if existing_user:
                raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

            picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
            if picture_claim:
                picture_url = user_data.get(
                    picture_claim,
                    OAUTH_PROVIDERS[provider].get("picture_url", ""),
                )
                picture_url = await self._process_picture_url(
                    picture_url, token.get("access_token")
                )
            else:
                picture_url = "/user.png"

            username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
            name = user_data.get(username_claim)
            if not name:
                log.warning("Username claim is missing, using email as name")
                name = email

            user = Auths.insert_new_auth(
                email=email,
                password=get_password_hash(
                    str(uuid.uuid4())
                ),  # Random password, not used
                name=name,
                profile_image_url=picture_url,
                role=self.get_user_role(None, user_data),
                oauth_sub=provider_sub,
            )

            if auth_manager_config.WEBHOOK_URL:
                await post_webhook(
                    WEBUI_NAME,
                    auth_manager_config.WEBHOOK_URL,
                    WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                    {
                        "action": "signup",
                        "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                        "user": user.model_dump_json(exclude_none=True),
                    },
                )

        jwt_token = create_token(
            data={"id": user.id},
            expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
        )
        if (
            auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT
            and user.role != "admin"
        ):
            self.update_user_groups(
                user=user,
                user_data=user_data,
                default_permissions=request.app.state.config.USER_PERMISSIONS,
            )
    except Exception as e:
        log.error(f"Error during OAuth process: {e}")
        error_message = (
            e.detail
            if isinstance(e, HTTPException) and e.detail
            else ERROR_MESSAGES.DEFAULT("Error during OAuth process")
        )

    redirect_base_url = (
        str(request.app.state.config.WEBUI_URL or request.base_url)
    ).rstrip("/")
    redirect_url = f"{redirect_base_url}/auth"

    if error_message:
        # Percent-encode the message so characters such as "&", "#" or "+"
        # cannot truncate or split the ``error`` query parameter.
        encoded_error = quote(f"{error_message}", safe="")
        redirect_url = f"{redirect_url}?error={encoded_error}"
        return RedirectResponse(url=redirect_url, headers=response.headers)

    response = RedirectResponse(url=redirect_url, headers=response.headers)

    # Set the cookie token
    # Redirect back to the frontend with the JWT token
    response.set_cookie(
        key="token",
        value=jwt_token,
        httponly=False,  # Required for frontend access
        samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
        secure=WEBUI_AUTH_COOKIE_SECURE,
    )

    # Legacy cookies for compatibility with older frontend versions.
    # Skipped when the provider issued no id_token, so the cookie is never
    # set to the literal string "None".
    id_token = token.get("id_token")
    if ENABLE_OAUTH_ID_TOKEN_COOKIE and id_token:
        response.set_cookie(
            key="oauth_id_token",
            value=id_token,
            httponly=True,
            samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
            secure=WEBUI_AUTH_COOKIE_SECURE,
        )

    # Storing the server-side session is best-effort: a failure is logged
    # and the login still succeeds with the JWT cookie alone.
    try:
        # Add timestamp for tracking
        issued_at = datetime.now().timestamp()
        token["issued_at"] = issued_at
        # Calculate expires_at if we have expires_in
        if "expires_in" in token and "expires_at" not in token:
            token["expires_at"] = issued_at + token["expires_in"]

        # Clean up any existing sessions for this user/provider first
        for stale_session in OAuthSessions.get_sessions_by_user_id(user.id):
            if stale_session.provider == provider:
                OAuthSessions.delete_session_by_id(stale_session.id)

        oauth_session = OAuthSessions.create_session(
            user_id=user.id,
            provider=provider,
            token=token,
        )
        response.set_cookie(
            key="oauth_session_id",
            value=oauth_session.id,
            httponly=True,
            samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
            secure=WEBUI_AUTH_COOKIE_SECURE,
        )
        log.info(
            f"Stored OAuth session server-side for user {user.id}, provider {provider}"
        )
    except Exception as e:
        log.error(f"Failed to store OAuth session server-side: {e}")

    return response
|