12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346 |
- import base64
- import hashlib
- import logging
- import mimetypes
- import sys
- import urllib
- import uuid
- import json
- from datetime import datetime, timedelta
- import re
- import fnmatch
- import time
- import secrets
- from cryptography.fernet import Fernet
- import aiohttp
- from authlib.integrations.starlette_client import OAuth
- from authlib.oidc.core import UserInfo
- from fastapi import (
- HTTPException,
- status,
- )
- from starlette.responses import RedirectResponse
- from typing import Optional
- from open_webui.models.auths import Auths
- from open_webui.models.oauth_sessions import OAuthSessions
- from open_webui.models.users import Users
- from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
- from open_webui.config import (
- DEFAULT_USER_ROLE,
- ENABLE_OAUTH_SIGNUP,
- OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
- OAUTH_PROVIDERS,
- ENABLE_OAUTH_ROLE_MANAGEMENT,
- ENABLE_OAUTH_GROUP_MANAGEMENT,
- ENABLE_OAUTH_GROUP_CREATION,
- OAUTH_BLOCKED_GROUPS,
- OAUTH_ROLES_CLAIM,
- OAUTH_SUB_CLAIM,
- OAUTH_GROUPS_CLAIM,
- OAUTH_EMAIL_CLAIM,
- OAUTH_PICTURE_CLAIM,
- OAUTH_USERNAME_CLAIM,
- OAUTH_ALLOWED_ROLES,
- OAUTH_ADMIN_ROLES,
- OAUTH_ALLOWED_DOMAINS,
- OAUTH_UPDATE_PICTURE_ON_LOGIN,
- WEBHOOK_URL,
- JWT_EXPIRES_IN,
- AppConfig,
- )
- from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
- from open_webui.env import (
- AIOHTTP_CLIENT_SESSION_SSL,
- WEBUI_NAME,
- WEBUI_AUTH_COOKIE_SAME_SITE,
- WEBUI_AUTH_COOKIE_SECURE,
- ENABLE_OAUTH_ID_TOKEN_COOKIE,
- OAUTH_CLIENT_INFO_ENCRYPTION_KEY,
- )
- from open_webui.utils.misc import parse_duration
- from open_webui.utils.auth import get_password_hash, create_token
- from open_webui.utils.webhook import post_webhook
- from mcp.shared.auth import (
- OAuthClientMetadata,
- OAuthMetadata,
- )
class OAuthClientInformationFull(OAuthClientMetadata):
    # Client metadata extended with the identifiers/credentials of a registered
    # client. This module validates the JSON returned by the dynamic client
    # registration endpoint against this model.

    # URL of the OAuth server that issued this client
    issuer: str | None = None
    client_id: str
    client_secret: str | None = None
    # NOTE(review): the two fields below look like epoch timestamps as used in
    # OAuth dynamic client registration responses — confirm against the server.
    client_id_issued_at: int | None = None
    client_secret_expires_at: int | None = None
- from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
# Root logging goes to stdout at the global level; this module's logger uses
# the level configured under the "OAUTH" key.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OAUTH"])

# Copy the OAuth-related settings imported from open_webui.config onto one
# module-level AppConfig instance; the managers below read them from here.
auth_manager_config = AppConfig()

# Sign-up and account handling
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL

# Role / group management switches
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT
auth_manager_config.ENABLE_OAUTH_GROUP_CREATION = ENABLE_OAUTH_GROUP_CREATION
auth_manager_config.OAUTH_BLOCKED_GROUPS = OAUTH_BLOCKED_GROUPS

# Names of the claims read from the provider's user data
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
auth_manager_config.OAUTH_SUB_CLAIM = OAUTH_SUB_CLAIM
auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM

# Allow-lists
auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
auth_manager_config.OAUTH_ALLOWED_DOMAINS = OAUTH_ALLOWED_DOMAINS

# Miscellaneous
auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN = OAUTH_UPDATE_PICTURE_ON_LOGIN
FERNET = None

# A configured value that is exactly 44 characters long is used as the Fernet
# key as-is. Any other length is treated as a passphrase: it is hashed with
# SHA-256 and the 32-byte digest is url-safe base64 encoded to form the key.
OAUTH_CLIENT_INFO_ENCRYPTION_KEY = (
    OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()
    if len(OAUTH_CLIENT_INFO_ENCRYPTION_KEY) == 44
    else base64.urlsafe_b64encode(
        hashlib.sha256(OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()).digest()
    )
)

try:
    FERNET = Fernet(OAUTH_CLIENT_INFO_ENCRYPTION_KEY)
except Exception as e:
    # Fail at import time: without a usable key nothing can be encrypted/decrypted.
    log.error(f"Error initializing Fernet with provided key: {e}")
    raise
def encrypt_data(data) -> str:
    """Serialize ``data`` to JSON and return it encrypted with the module Fernet key.

    Args:
        data: Any JSON-serializable value.

    Returns:
        The encrypted token as text.

    Raises:
        Exception: Any serialization or encryption error is logged and re-raised.
    """
    try:
        payload = json.dumps(data).encode()
        return FERNET.encrypt(payload).decode()
    except Exception as e:
        log.error(f"Error encrypting data: {e}")
        raise
def decrypt_data(data: str):
    """Decrypt a token produced by ``encrypt_data`` and return the decoded JSON value.

    Args:
        data: The encrypted token as text.

    Returns:
        The value obtained by JSON-decoding the decrypted payload.

    Raises:
        Exception: Any decryption or JSON decoding error is logged and re-raised.
    """
    try:
        plaintext = FERNET.decrypt(data.encode()).decode()
        return json.loads(plaintext)
    except Exception as e:
        log.error(f"Error decrypting data: {e}")
        raise
def is_in_blocked_groups(group_name: str, groups: list) -> bool:
    """Tell whether ``group_name`` matches any entry of a block list.

    An entry matches when it is equal to the name, when it contains one of the
    regex marker characters and ``re.search`` finds it in the name, or when it
    contains ``*`` / ``?`` and matches as a shell-style wildcard.

    Args:
        group_name: The group name to check.
        groups: Block-list patterns; empty entries are ignored.

    Returns:
        True if at least one pattern matches, False otherwise (including when
        ``groups`` is empty or None).
    """
    # Presence of any of these characters makes a pattern a regex candidate.
    regex_markers = "^$[](){}+\\|"

    for pattern in groups or ():
        if not pattern:
            continue

        if group_name == pattern:
            return True

        if any(marker in pattern for marker in regex_markers):
            try:
                regex_hit = re.search(pattern, group_name)
            except re.error:
                # Not a valid regex after all; the wildcard check below still applies.
                regex_hit = None
            if regex_hit:
                return True

        has_wildcard = "*" in pattern or "?" in pattern
        if has_wildcard and fnmatch.fnmatch(group_name, pattern):
            return True

    return False
def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]:
    """Parse ``server_url`` and also build its ``scheme://netloc`` prefix.

    Returns:
        A ``(parsed, base_url)`` pair, where ``parsed`` is the result of
        ``urllib.parse.urlparse`` and ``base_url`` has no path, query or fragment.
    """
    url_parts = urllib.parse.urlparse(server_url)
    return url_parts, f"{url_parts.scheme}://{url_parts.netloc}"
def get_discovery_urls(server_url) -> list[str]:
    """List the well-known metadata URLs to probe for ``server_url``.

    The two host-root documents come first. When the server URL has a path
    other than ``/``, the same two documents with that path (trailing slash
    removed) appended after the well-known segment are added.

    Returns:
        Candidate URLs in the order they should be tried.
    """
    parsed, base_url = get_parsed_and_base_url(server_url)
    well_known_paths = (
        "/.well-known/oauth-authorization-server",
        "/.well-known/openid-configuration",
    )

    candidates = [urllib.parse.urljoin(base_url, path) for path in well_known_paths]

    if parsed.path and parsed.path != "/":
        candidates.extend(
            urllib.parse.urljoin(base_url, f"{path}{parsed.path.rstrip('/')}")
            for path in well_known_paths
        )

    return candidates
- # TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration.
- # This is not currently supported.
async def get_oauth_client_info_with_dynamic_client_registration(
    request,
    client_id: str,
    oauth_server_url: str,
    oauth_server_key: Optional[str] = None,
) -> OAuthClientInformationFull:
    """Register Open WebUI as a client at an OAuth server and return the result.

    The server's metadata document is looked up through the URLs returned by
    ``get_discovery_urls``. The first one that answers 200 with parsable
    metadata supplies the registration endpoint and, if advertised, the scopes
    to request. Without usable metadata the request goes to ``/register`` on
    the server's host.

    Args:
        request: Incoming request; supplies ``app.state.config.WEBUI_URL`` and
            ``base_url`` for building the redirect URI.
        client_id: Local identifier of the client, used in the callback path.
        oauth_server_url: URL of the OAuth server to register with.
        oauth_server_key: Accepted for interface compatibility; not used here.

    Returns:
        The validated registration response. Its ``issuer`` field is set to
        the metadata URL that was successfully fetched (None if none was).

    Raises:
        Exception: If any request fails or the registration response cannot be
            validated. The error is logged before it propagates.
    """
    try:
        oauth_server_metadata = None
        oauth_server_metadata_url = None

        redirect_base_url = (
            str(request.app.state.config.WEBUI_URL or request.base_url)
        ).rstrip("/")

        oauth_client_metadata = OAuthClientMetadata(
            client_name="Open WebUI",
            redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            token_endpoint_auth_method="client_secret_post",
        )

        # Attempt to fetch OAuth server metadata to get registration endpoint & scopes.
        # One HTTP session is shared by all discovery probes.
        async with aiohttp.ClientSession() as session:
            for url in get_discovery_urls(oauth_server_url):
                async with session.get(
                    url, ssl=AIOHTTP_CLIENT_SESSION_SSL
                ) as oauth_server_metadata_response:
                    if oauth_server_metadata_response.status != 200:
                        continue

                    try:
                        oauth_server_metadata = OAuthMetadata.model_validate(
                            await oauth_server_metadata_response.json()
                        )
                        oauth_server_metadata_url = url
                        if (
                            oauth_client_metadata.scope is None
                            and oauth_server_metadata.scopes_supported is not None
                        ):
                            oauth_client_metadata.scope = " ".join(
                                oauth_server_metadata.scopes_supported
                            )
                        break
                    except Exception as e:
                        log.error(f"Error parsing OAuth metadata from {url}: {e}")
                        continue

        if oauth_server_metadata and oauth_server_metadata.registration_endpoint:
            registration_url = str(oauth_server_metadata.registration_endpoint)
        else:
            _, base_url = get_parsed_and_base_url(oauth_server_url)
            registration_url = urllib.parse.urljoin(base_url, "/register")

        registration_data = oauth_client_metadata.model_dump(
            exclude_none=True,
            mode="json",
            by_alias=True,
        )

        # Perform dynamic client registration and return client info
        async with aiohttp.ClientSession() as session:
            async with session.post(
                registration_url, json=registration_data, ssl=AIOHTTP_CLIENT_SESSION_SSL
            ) as oauth_client_registration_response:
                try:
                    registration_response_json = (
                        await oauth_client_registration_response.json()
                    )
                    oauth_client_info = OAuthClientInformationFull.model_validate(
                        {
                            **registration_response_json,
                            "issuer": oauth_server_metadata_url,
                        }
                    )
                    log.info(
                        f"Dynamic client registration successful at {registration_url}, client_id: {oauth_client_info.client_id}"
                    )
                    return oauth_client_info
                except Exception as e:
                    error_text = None
                    try:
                        error_text = await oauth_client_registration_response.text()
                        log.error(
                            f"Dynamic client registration failed at {registration_url}: {oauth_client_registration_response.status} - {error_text}"
                        )
                    except Exception as read_error:
                        # A distinct name is required here: reusing ``e`` would
                        # unbind the outer exception when this handler exits.
                        log.debug(
                            f"Could not read client registration error body: {read_error}"
                        )

                    log.error(f"Error parsing client registration response: {e}")
                    raise Exception(
                        f"Dynamic client registration failed: {error_text}"
                        if error_text
                        else "Error parsing client registration response"
                    ) from e

        # Defensive: the block above always returns or raises.
        raise Exception("Dynamic client registration failed")
    except Exception as e:
        log.error(f"Exception during dynamic client registration: {e}")
        raise
class OAuthClientManager:
    """Registry of dynamically added OAuth clients and their per-user tokens.

    ``self.clients`` maps a client id to a dict with two keys: ``"client"``
    (the object returned by ``self.oauth.register``) and ``"client_info"``
    (the ``OAuthClientInformationFull`` it was registered from).
    """

    def __init__(self, app):
        self.oauth = OAuth()
        self.app = app
        self.clients = {}

    def _forget_registration(self, client_id):
        """Drop whatever the authlib registry has cached under ``client_id``."""
        # NOTE(review): authlib's OAuth registry caches created clients by name
        # in the private ``_clients`` / ``_registry`` dicts, so registering the
        # same name again would hand back the old client. The attribute names
        # are authlib internals — confirm on authlib upgrades.
        for cache_name in ("_clients", "_registry"):
            cache = getattr(self.oauth, cache_name, None)
            if isinstance(cache, dict):
                cache.pop(client_id, None)

    def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull):
        """Register a client under ``client_id`` and return its registry entry.

        Returns:
            The ``{"client": ..., "client_info": ...}`` dict stored for the client.
        """
        # Make sure a previous registration under the same id is not reused.
        self._forget_registration(client_id)

        self.clients[client_id] = {
            "client": self.oauth.register(
                name=client_id,
                client_id=oauth_client_info.client_id,
                client_secret=oauth_client_info.client_secret,
                client_kwargs=(
                    {"scope": oauth_client_info.scope}
                    if oauth_client_info.scope
                    else {}
                ),
                server_metadata_url=(
                    oauth_client_info.issuer if oauth_client_info.issuer else None
                ),
            ),
            "client_info": oauth_client_info,
        }
        return self.clients[client_id]

    def remove_client(self, client_id):
        """Remove the client registered under ``client_id``; always returns True."""
        if client_id in self.clients:
            del self.clients[client_id]
            log.info(f"Removed OAuth client {client_id}")
        self._forget_registration(client_id)
        return True

    def get_client(self, client_id):
        """Return the registered client object, or None if the id is unknown."""
        client = self.clients.get(client_id)
        return client["client"] if client else None

    def get_client_info(self, client_id):
        """Return the stored registration info, or None if the id is unknown."""
        client = self.clients.get(client_id)
        return client["client_info"] if client else None

    def get_server_metadata_url(self, client_id):
        """Return the metadata document URL of a registered client, if known.

        The registry entry is a dict, so the URL is read from the client object
        inside it. If the client object does not expose it, the ``issuer`` of
        the stored client info is used — that is the value ``add_client``
        passes as ``server_metadata_url``.

        Returns:
            The URL, or None when the client is unknown or has no metadata URL.
        """
        entry = self.clients.get(client_id)
        if not entry:
            return None

        client = entry["client"]
        # NOTE(review): authlib keeps this value in ``_server_metadata_url``;
        # the public spelling is checked first in case a client exposes it.
        for attr_name in ("server_metadata_url", "_server_metadata_url"):
            url = getattr(client, attr_name, None)
            if url:
                return url

        return getattr(entry["client_info"], "issuer", None) or None

    async def get_oauth_token(
        self, user_id: str, client_id: str, force_refresh: bool = False
    ):
        """
        Get a valid OAuth token for the user, automatically refreshing if needed.

        A refresh is attempted when ``force_refresh`` is set or the stored
        session expires within the next five minutes. If the refresh fails the
        stored session is deleted.

        Args:
            user_id: The user ID
            client_id: The OAuth client ID (provider)
            force_refresh: Force token refresh even if current token appears valid

        Returns:
            dict: OAuth token data with access_token, or None if no valid token available
        """
        try:
            session = OAuthSessions.get_session_by_provider_and_user_id(
                client_id, user_id
            )
            if not session:
                log.warning(
                    f"No OAuth session found for user {user_id}, client_id {client_id}"
                )
                return None

            needs_refresh = force_refresh or (
                datetime.now() + timedelta(minutes=5)
                >= datetime.fromtimestamp(session.expires_at)
            )
            if not needs_refresh:
                return session.token

            log.debug(
                f"Token refresh needed for user {user_id}, client_id {session.provider}"
            )
            refreshed_token = await self._refresh_token(session)
            if refreshed_token:
                return refreshed_token

            log.warning(
                f"Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}"
            )
            OAuthSessions.delete_session_by_id(session.id)
            return None
        except Exception as e:
            log.error(f"Error getting OAuth token for user {user_id}: {e}")
            return None

    async def _refresh_token(self, session) -> dict:
        """
        Refresh the token of a stored session and persist the result.

        Args:
            session: The OAuth session object

        Returns:
            dict: Refreshed token data, or None if the refresh or the update
            of the stored session failed
        """
        # Captured up front so the error path never depends on a rebound name.
        session_id = session.id
        try:
            refreshed_token = await self._perform_token_refresh(session)
            if not refreshed_token:
                log.error(f"Failed to refresh token for session {session_id}")
                return None

            updated_session = OAuthSessions.update_session_by_id(
                session_id, refreshed_token
            )
            if not updated_session:
                log.error(
                    f"Failed to store refreshed token for session {session_id}"
                )
                return None

            log.info(f"Successfully refreshed token for session {session_id}")
            return updated_session.token
        except Exception as e:
            log.error(f"Error refreshing token for session {session_id}: {e}")
            return None

    async def _perform_token_refresh(self, session) -> dict:
        """
        Perform the actual OAuth token refresh.

        The token endpoint is read from the client's metadata document, then a
        ``refresh_token`` grant is posted to it with the client id and, when
        the client has one, the client secret.

        Args:
            session: The OAuth session object

        Returns:
            dict: New token data, or None if refresh failed
        """
        client_id = session.provider
        token_data = session.token

        if not token_data.get("refresh_token"):
            log.warning(f"No refresh token available for session {session.id}")
            return None

        try:
            client = self.get_client(client_id)
            if not client:
                log.error(f"No OAuth client found for provider {client_id}")
                return None

            token_endpoint = None
            server_metadata_url = self.get_server_metadata_url(client_id)
            if server_metadata_url:
                async with aiohttp.ClientSession(trust_env=True) as session_http:
                    async with session_http.get(
                        server_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL
                    ) as r:
                        if r.status == 200:
                            openid_data = await r.json()
                            token_endpoint = openid_data.get("token_endpoint")
                        else:
                            log.error(
                                f"Failed to fetch OpenID configuration for client_id {client_id}"
                            )

            if not token_endpoint:
                log.error(f"No token endpoint found for client_id {client_id}")
                return None

            # Prepare refresh request
            refresh_data = {
                "grant_type": "refresh_token",
                "refresh_token": token_data["refresh_token"],
                "client_id": client.client_id,
            }
            if hasattr(client, "client_secret") and client.client_secret:
                refresh_data["client_secret"] = client.client_secret

            # Make refresh request
            async with aiohttp.ClientSession(trust_env=True) as session_http:
                async with session_http.post(
                    token_endpoint,
                    data=refresh_data,
                    headers={"Content-Type": "application/x-www-form-urlencoded"},
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as r:
                    if r.status != 200:
                        error_text = await r.text()
                        log.error(
                            f"Token refresh failed for client_id {client_id}: {r.status} - {error_text}"
                        )
                        return None

                    new_token_data = await r.json()

                    # Keep the old refresh token when the server does not rotate it
                    if "refresh_token" not in new_token_data:
                        new_token_data["refresh_token"] = token_data["refresh_token"]

                    # Add timestamp for tracking
                    new_token_data["issued_at"] = datetime.now().timestamp()

                    # Calculate expires_at if we have expires_in
                    if (
                        "expires_in" in new_token_data
                        and "expires_at" not in new_token_data
                    ):
                        new_token_data["expires_at"] = int(
                            datetime.now().timestamp() + new_token_data["expires_in"]
                        )

                    log.debug(f"Token refresh successful for client_id {client_id}")
                    return new_token_data
        except Exception as e:
            log.error(f"Exception during token refresh for client_id {client_id}: {e}")
            return None

    async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
        """Start the authorization flow by redirecting to the provider.

        Raises:
            HTTPException: 404 when no client or client info exists for ``client_id``.
        """
        client = self.get_client(client_id)
        if client is None:
            raise HTTPException(404)

        client_info = self.get_client_info(client_id)
        if client_info is None:
            raise HTTPException(404)

        # The first registered redirect URI is the one sent to the provider.
        redirect_uri = (
            client_info.redirect_uris[0] if client_info.redirect_uris else None
        )
        return await client.authorize_redirect(request, str(redirect_uri))

    async def handle_callback(self, request, client_id: str, user_id: str, response):
        """Finish the authorization flow and store the token for the user.

        Existing sessions of the user for this client are replaced by a new
        one holding the obtained token. The caller is always redirected to the
        WebUI base URL; on failure an ``error`` query parameter is appended.

        Raises:
            HTTPException: 404 when no client exists for ``client_id``.
        """
        client = self.get_client(client_id)
        if client is None:
            raise HTTPException(404)

        error_message = None
        try:
            token = await client.authorize_access_token(request)
            if token:
                try:
                    # Add timestamp for tracking
                    token["issued_at"] = datetime.now().timestamp()

                    # Calculate expires_at if we have expires_in
                    if "expires_in" in token and "expires_at" not in token:
                        token["expires_at"] = (
                            datetime.now().timestamp() + token["expires_in"]
                        )

                    # Clean up any existing sessions for this user/client_id first
                    for existing in OAuthSessions.get_sessions_by_user_id(user_id):
                        if existing.provider == client_id:
                            OAuthSessions.delete_session_by_id(existing.id)

                    OAuthSessions.create_session(
                        user_id=user_id,
                        provider=client_id,
                        token=token,
                    )
                    log.info(
                        f"Stored OAuth session server-side for user {user_id}, client_id {client_id}"
                    )
                except Exception as e:
                    error_message = "Failed to store OAuth session server-side"
                    log.error(f"Failed to store OAuth session server-side: {e}")
            else:
                error_message = "Failed to obtain OAuth token"
                log.warning(error_message)
        except Exception as e:
            error_message = "OAuth callback error"
            log.warning(f"OAuth callback error: {e}")

        redirect_url = (
            str(request.app.state.config.WEBUI_URL or request.base_url)
        ).rstrip("/")

        if error_message:
            log.debug(error_message)
            redirect_url = f"{redirect_url}/?error={error_message}"

        return RedirectResponse(url=redirect_url, headers=response.headers)
- class OAuthManager:
def __init__(self, app):
    """Create the OAuth registry and register every configured provider.

    Each provider config is expected to carry a ``register`` callable that is
    invoked with the registry; its result is cached under the provider's
    ``name``. Providers without ``register`` are logged and skipped.
    """
    self.oauth = OAuth()
    self.app = app
    self._clients = {}

    for provider_config in OAUTH_PROVIDERS.values():
        if "register" in provider_config:
            registered_client = provider_config["register"](self.oauth)
            self._clients[provider_config["name"]] = registered_client
        else:
            log.error(
                f"OAuth provider {provider_config['name']} missing register function"
            )
- def get_client(self, provider_name):
- if provider_name not in self._clients:
- self._clients[provider_name] = self.oauth.create_client(provider_name)
- return self._clients[provider_name]
- def get_server_metadata_url(self, provider_name):
- if provider_name in self._clients:
- client = self._clients[provider_name]
- return (
- client.server_metadata_url
- if hasattr(client, "server_metadata_url")
- else None
- )
- return None
async def get_oauth_token(
    self, user_id: str, session_id: str, force_refresh: bool = False
):
    """
    Get a valid OAuth token for the user, automatically refreshing if needed.

    A refresh is attempted when ``force_refresh`` is set or the stored session
    expires within the next five minutes. If the refresh fails the stored
    session is deleted.

    Args:
        user_id: The user ID
        session_id: ID of the stored OAuth session to read the token from
        force_refresh: Force token refresh even if current token appears valid

    Returns:
        dict: OAuth token data with access_token, or None if no valid token available
    """
    try:
        session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
        if not session:
            log.warning(
                f"No OAuth session found for user {user_id}, session {session_id}"
            )
            return None

        needs_refresh = force_refresh or (
            datetime.now() + timedelta(minutes=5)
            >= datetime.fromtimestamp(session.expires_at)
        )
        if not needs_refresh:
            return session.token

        log.debug(
            f"Token refresh needed for user {user_id}, provider {session.provider}"
        )
        refreshed_token = await self._refresh_token(session)
        if refreshed_token:
            return refreshed_token

        log.warning(
            f"Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}"
        )
        OAuthSessions.delete_session_by_id(session.id)
        return None
    except Exception as e:
        log.error(f"Error getting OAuth token for user {user_id}: {e}")
        return None
async def _refresh_token(self, session) -> dict:
    """
    Refresh the token of a stored session and persist the result.

    Args:
        session: The OAuth session object

    Returns:
        dict: Refreshed token data, or None if the refresh or the update of
        the stored session failed
    """
    # Captured up front: the update call below may return None, and the log
    # messages (including the one in the error handler) must still be able to
    # name the session.
    session_id = session.id
    try:
        refreshed_token = await self._perform_token_refresh(session)
        if not refreshed_token:
            log.error(f"Failed to refresh token for session {session_id}")
            return None

        updated_session = OAuthSessions.update_session_by_id(
            session_id, refreshed_token
        )
        if not updated_session:
            log.error(f"Failed to store refreshed token for session {session_id}")
            return None

        log.info(f"Successfully refreshed token for session {session_id}")
        return updated_session.token
    except Exception as e:
        log.error(f"Error refreshing token for session {session_id}: {e}")
        return None
async def _perform_token_refresh(self, session) -> dict:
    """
    Perform the actual OAuth token refresh.

    The token endpoint is read from the provider's metadata document, then a
    ``refresh_token`` grant is posted to it with the client id and, when the
    client has one, the client secret.

    Args:
        session: The OAuth session object

    Returns:
        dict: New token data, or None if refresh failed
    """
    provider = session.provider
    token_data = session.token

    if not token_data.get("refresh_token"):
        log.warning(f"No refresh token available for session {session.id}")
        return None

    try:
        client = self.get_client(provider)
        if not client:
            log.error(f"No OAuth client found for provider {provider}")
            return None

        token_endpoint = None
        server_metadata_url = self.get_server_metadata_url(provider)
        # Without a metadata URL there is nothing to fetch; fall through to the
        # "no token endpoint" error instead of issuing a request for None.
        if server_metadata_url:
            async with aiohttp.ClientSession(trust_env=True) as session_http:
                # Same TLS setting as the token request below.
                async with session_http.get(
                    server_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL
                ) as r:
                    if r.status == 200:
                        openid_data = await r.json()
                        token_endpoint = openid_data.get("token_endpoint")
                    else:
                        log.error(
                            f"Failed to fetch OpenID configuration for provider {provider}"
                        )

        if not token_endpoint:
            log.error(f"No token endpoint found for provider {provider}")
            return None

        # Prepare refresh request
        refresh_data = {
            "grant_type": "refresh_token",
            "refresh_token": token_data["refresh_token"],
            "client_id": client.client_id,
        }
        # Add client_secret if available (some providers require it)
        if hasattr(client, "client_secret") and client.client_secret:
            refresh_data["client_secret"] = client.client_secret

        # Make refresh request
        async with aiohttp.ClientSession(trust_env=True) as session_http:
            async with session_http.post(
                token_endpoint,
                data=refresh_data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                if r.status != 200:
                    error_text = await r.text()
                    log.error(
                        f"Token refresh failed for provider {provider}: {r.status} - {error_text}"
                    )
                    return None

                new_token_data = await r.json()

                # Keep the old refresh token when the server does not rotate it
                if "refresh_token" not in new_token_data:
                    new_token_data["refresh_token"] = token_data["refresh_token"]

                # Add timestamp for tracking
                new_token_data["issued_at"] = datetime.now().timestamp()

                # Calculate expires_at if we have expires_in
                if (
                    "expires_in" in new_token_data
                    and "expires_at" not in new_token_data
                ):
                    new_token_data["expires_at"] = int(
                        datetime.now().timestamp() + new_token_data["expires_in"]
                    )

                log.debug(f"Token refresh successful for provider {provider}")
                return new_token_data
    except Exception as e:
        log.error(f"Exception during token refresh for provider {provider}: {e}")
        return None
def get_user_role(self, user, user_data):
    """Decide which role a user gets on OAuth login.

    Order of decisions:
      1. An existing user who is the only user, or a new user when there are
         no users yet, is "admin".
      2. With role management enabled, the roles claim (dot-separated path
         into ``user_data``) is read. A match with the allowed roles gives
         "user", a match with the admin roles gives "admin" (admin wins);
         otherwise the configured default role applies.
      3. With role management disabled, new users get the default role and
         existing users keep their current one.

    Args:
        user: The existing user object, or a falsy value for a new user.
        user_data: Claims mapping received from the provider.

    Returns:
        The role name.
    """
    user_count = Users.get_num_users()
    if user and user_count == 1:
        # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
        log.debug("Assigning the only user the admin role")
        return "admin"
    if not user and user_count == 0:
        # If there are no users, assign the role "admin", as the first user will be an admin
        log.debug("Assigning the first user the admin role")
        return "admin"

    if not auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
        # Role management disabled: default role for new users, existing role otherwise
        return user.role if user else auth_manager_config.DEFAULT_USER_ROLE

    log.debug("Running OAUTH Role management")
    oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
    oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
    oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
    oauth_roles = []
    # Default/fallback role if no matching roles are found
    role = auth_manager_config.DEFAULT_USER_ROLE

    # Extract the roles from the user data, accepting nested claims of any depth
    if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
        claim_data = user_data
        for nested_claim in oauth_claim.split("."):
            try:
                claim_data = claim_data.get(nested_claim, {})
            except AttributeError:
                # An intermediate value is not a mapping (e.g. a string, list
                # or null), so the path cannot be followed: treat as "no roles".
                claim_data = {}
                break

        if isinstance(claim_data, list):
            oauth_roles = claim_data
        if isinstance(claim_data, str) or isinstance(claim_data, int):
            oauth_roles = [str(claim_data)]

    log.debug(f"Oauth Roles claim: {oauth_claim}")
    log.debug(f"User roles from oauth: {oauth_roles}")
    log.debug(f"Accepted user roles: {oauth_allowed_roles}")
    log.debug(f"Accepted admin roles: {oauth_admin_roles}")

    # If any roles are found, check if they match the allowed or admin roles
    if oauth_roles:
        if any(allowed_role in oauth_roles for allowed_role in oauth_allowed_roles):
            log.debug("Assigned user the user role")
            role = "user"
        # Checked last so an admin match overrides a plain user match
        if any(admin_role in oauth_roles for admin_role in oauth_admin_roles):
            log.debug("Assigned user the admin role")
            role = "admin"

    return role
def update_user_groups(self, user, user_data, default_permissions):
    """Synchronise a user's group memberships with their OAuth groups claim.

    The groups claim named by ``auth_manager_config.OAUTH_GROUPS_CLAIM`` is
    looked up in ``user_data`` (dots in the claim name denote nesting). When
    ``ENABLE_OAUTH_GROUP_CREATION`` is set, groups named in the claim that do
    not exist yet are created. The user is then removed from current groups
    that are absent from the claim and added to existing groups that are
    present in it. Groups matched by ``is_in_blocked_groups`` are never
    touched, and nothing is added or removed when the claim yields no groups.

    Args:
        user: The user being synchronised; only ``user.id`` is read here.
        user_data: Claims returned by the OAuth provider for this login.
        default_permissions: Permissions given to auto-created groups and used
            for existing groups whose own permissions are empty.
    """
    log.debug("Running OAUTH Group management")
    oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM

    try:
        blocked_groups = json.loads(auth_manager_config.OAUTH_BLOCKED_GROUPS)
    except Exception as e:
        log.exception(f"Error loading OAUTH_BLOCKED_GROUPS: {e}")
        blocked_groups = []

    user_oauth_groups = []
    # Nested claim search for groups claim
    if oauth_claim:
        claim_data = user_data
        for nested_claim in oauth_claim.split("."):
            # An intermediate value that is a string, list, None, ... cannot be
            # descended into; treat the claim as absent instead of raising.
            if not hasattr(claim_data, "get"):
                claim_data = None
                break
            claim_data = claim_data.get(nested_claim, {})

        if isinstance(claim_data, list):
            user_oauth_groups = claim_data
        elif isinstance(claim_data, str):
            user_oauth_groups = [claim_data]
        else:
            user_oauth_groups = []

    user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
    all_available_groups: list[GroupModel] = Groups.get_groups()

    # Create groups if they don't exist and creation is enabled
    if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
        log.debug("Checking for missing groups to create...")
        all_group_names = {g.name for g in all_available_groups}
        groups_created = False

        # Determine creator ID: Prefer admin, fallback to current user if no admin exists
        admin_user = Users.get_super_admin_user()
        creator_id = admin_user.id if admin_user else user.id
        log.debug(f"Using creator ID {creator_id} for potential group creation.")

        for group_name in user_oauth_groups:
            if group_name in all_group_names:
                continue

            log.info(
                f"Group '{group_name}' not found via OAuth claim. Creating group..."
            )
            try:
                new_group_form = GroupForm(
                    name=group_name,
                    description=f"Group '{group_name}' created automatically via OAuth.",
                    permissions=default_permissions,
                    # Start empty; the user is added by the membership sync below.
                    user_ids=[],
                )
                created_group = Groups.insert_new_group(creator_id, new_group_form)
                if created_group:
                    log.info(
                        f"Successfully created group '{group_name}' with ID {created_group.id} using creator ID {creator_id}"
                    )
                    groups_created = True
                    # Add to local set to prevent duplicate creation attempts in this run
                    all_group_names.add(group_name)
                else:
                    log.error(f"Failed to create group '{group_name}' via OAuth.")
            except Exception as e:
                # Best effort: one failing group must not abort the login.
                log.error(f"Error creating group '{group_name}' via OAuth: {e}")

        # Refresh the list of all available groups if any were created
        if groups_created:
            all_available_groups = Groups.get_groups()
            log.debug("Refreshed list of all available groups after creation.")

    log.debug(f"Oauth Groups claim: {oauth_claim}")
    log.debug(f"User oauth groups: {user_oauth_groups}")
    log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
    log.debug(
        f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
    )

    def _save_members(group_model, member_ids):
        # Persist a new member list for ``group_model``. Falls back to the
        # default permissions in case a group was created but perms were never
        # assigned to it by hitting "save".
        update_form = GroupUpdateForm(
            name=group_model.name,
            description=group_model.description,
            permissions=group_model.permissions or default_permissions,
            user_ids=member_ids,
        )
        Groups.update_group_by_id(
            id=group_model.id, form_data=update_form, overwrite=False
        )

    # Remove groups that user is no longer a part of
    for group_model in user_current_groups:
        if (
            user_oauth_groups
            and group_model.name not in user_oauth_groups
            and not is_in_blocked_groups(group_model.name, blocked_groups)
        ):
            log.debug(
                f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
            )
            remaining_ids = [i for i in (group_model.user_ids or []) if i != user.id]
            _save_members(group_model, remaining_ids)

    # Add user to new groups
    current_group_names = {g.name for g in user_current_groups}
    for group_model in all_available_groups:
        if (
            user_oauth_groups
            and group_model.name in user_oauth_groups
            and group_model.name not in current_group_names
            and not is_in_blocked_groups(group_model.name, blocked_groups)
        ):
            log.debug(
                f"Adding user to group {group_model.name} as it was found in their oauth groups"
            )
            # Build a new list rather than appending to the model's own list.
            extended_ids = [*(group_model.user_ids or []), user.id]
            _save_members(group_model, extended_ids)
- async def _process_picture_url(
- self, picture_url: str, access_token: str = None
- ) -> str:
- """Process a picture URL and return a base64 encoded data URL.
- Args:
- picture_url: The URL of the picture to process
- access_token: Optional OAuth access token for authenticated requests
- Returns:
- A data URL containing the base64 encoded picture, or "/user.png" if processing fails
- """
- if not picture_url:
- return "/user.png"
- try:
- get_kwargs = {}
- if access_token:
- get_kwargs["headers"] = {
- "Authorization": f"Bearer {access_token}",
- }
- async with aiohttp.ClientSession(trust_env=True) as session:
- async with session.get(
- picture_url, **get_kwargs, ssl=AIOHTTP_CLIENT_SESSION_SSL
- ) as resp:
- if resp.ok:
- picture = await resp.read()
- base64_encoded_picture = base64.b64encode(picture).decode(
- "utf-8"
- )
- guessed_mime_type = mimetypes.guess_type(picture_url)[0]
- if guessed_mime_type is None:
- guessed_mime_type = "image/jpeg"
- return (
- f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
- )
- else:
- log.warning(
- f"Failed to fetch profile picture from {picture_url}"
- )
- return "/user.png"
- except Exception as e:
- log.error(f"Error processing profile picture '{picture_url}': {e}")
- return "/user.png"
async def handle_login(self, request, provider):
    """Start the OAuth login flow for ``provider``.

    Args:
        request: The incoming request; used to build the callback URL and
            passed on to the OAuth client.
        provider: Key of the provider in ``OAUTH_PROVIDERS``.

    Returns:
        The result of the client's ``authorize_redirect`` call.

    Raises:
        HTTPException: 404 when the provider is not configured or no client
            is available for it.
    """
    if provider not in OAUTH_PROVIDERS:
        raise HTTPException(404)

    # A provider-specific redirect URL wins; otherwise derive one from the
    # "oauth_login_callback" route.
    configured_uri = OAUTH_PROVIDERS[provider].get("redirect_uri")
    if configured_uri:
        redirect_uri = configured_uri
    else:
        redirect_uri = request.url_for("oauth_login_callback", provider=provider)

    oauth_client = self.get_client(provider)
    if oauth_client is None:
        raise HTTPException(404)

    return await oauth_client.authorize_redirect(request, redirect_uri)
async def handle_callback(self, request, provider, response):
    """Complete the OAuth flow for ``provider`` and redirect to the frontend.

    Exchanges the callback for a token, resolves the user's claims, finds,
    merges or creates the local user, optionally syncs groups, and finally
    redirects to ``<base>/auth``. On success the redirect carries the ``token``
    cookie (plus optional ``oauth_id_token`` / ``oauth_session_id`` cookies);
    on failure it carries a percent-encoded ``error`` query parameter instead.

    Args:
        request: The incoming callback request.
        provider: Key of the provider in ``OAUTH_PROVIDERS``.
        response: Response object whose headers are copied onto the redirect.

    Returns:
        A ``RedirectResponse`` to the frontend auth page.

    Raises:
        HTTPException: 404 when the provider is not configured. Every other
            failure is reported through the ``error`` query parameter.
    """
    # Local import keeps this block self-contained.
    from urllib.parse import quote

    if provider not in OAUTH_PROVIDERS:
        raise HTTPException(404)

    error_message = None
    try:
        client = self.get_client(provider)
        try:
            token = await client.authorize_access_token(request)
        except Exception as e:
            log.warning(f"OAuth callback error: {e}")
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        # Try to get userinfo from the token first, some providers include it there
        user_data: UserInfo = token.get("userinfo")
        if (
            (not user_data)
            or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data)
            or (auth_manager_config.OAUTH_USERNAME_CLAIM not in user_data)
        ):
            user_data = await client.userinfo(token=token)
        # NOTE(review): nesting of this Feishu unwrap relative to the userinfo
        # fetch above could not be recovered from the flattened source - confirm.
        if (
            provider == "feishu"
            and isinstance(user_data, dict)
            and "data" in user_data
        ):
            user_data = user_data["data"]
        if not user_data:
            log.warning(f"OAuth callback failed, user data is missing: {token}")
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        # Extract the "sub" claim, using custom claim if configured
        if auth_manager_config.OAUTH_SUB_CLAIM:
            sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
        else:
            # Fallback to the default sub claim if not configured
            sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
        if not sub:
            log.warning(f"OAuth callback failed, sub is missing: {user_data}")
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
        provider_sub = f"{provider}@{sub}"

        # Email extraction
        email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
        email = user_data.get(email_claim, "")
        # We currently mandate that email addresses are provided
        if not email:
            # If the provider is GitHub and a public email is not provided, we
            # can use the access token to fetch the user's email
            if provider == "github":
                try:
                    access_token = token.get("access_token")
                    headers = {"Authorization": f"Bearer {access_token}"}
                    async with aiohttp.ClientSession(trust_env=True) as http_session:
                        async with http_session.get(
                            "https://api.github.com/user/emails",
                            headers=headers,
                            ssl=AIOHTTP_CLIENT_SESSION_SSL,
                        ) as resp:
                            if resp.ok:
                                emails = await resp.json()
                                # use the primary email as the user's email
                                primary_email = next(
                                    (
                                        entry["email"]
                                        for entry in emails
                                        if entry.get("primary")
                                    ),
                                    None,
                                )
                                if primary_email:
                                    email = primary_email
                                else:
                                    log.warning(
                                        "No primary email found in GitHub response"
                                    )
                                    raise HTTPException(
                                        400, detail=ERROR_MESSAGES.INVALID_CRED
                                    )
                            else:
                                log.warning("Failed to fetch GitHub email")
                                raise HTTPException(
                                    400, detail=ERROR_MESSAGES.INVALID_CRED
                                )
                except Exception as e:
                    # Also catches the HTTPExceptions raised just above; every
                    # failure here ends up as the same 400.
                    log.warning(f"Error fetching GitHub email: {e}")
                    raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
            else:
                log.warning(f"OAuth callback failed, email is missing: {user_data}")
                raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
        email = email.lower()

        # If allowed domains are configured, check if the email domain is in the list
        if (
            "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
            and email.split("@")[-1]
            not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
        ):
            log.warning(
                f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}"
            )
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        # Check if the user exists
        user = Users.get_user_by_oauth_sub(provider_sub)
        if not user:
            # If the user does not exist, check if merging is enabled
            if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
                # Check if the user exists by email
                user = Users.get_user_by_email(email)
                if user:
                    # Update the user with the new oauth sub
                    Users.update_user_oauth_sub_by_id(user.id, provider_sub)

        if user:
            determined_role = self.get_user_role(user, user_data)
            if user.role != determined_role:
                Users.update_user_role_by_id(user.id, determined_role)
            # Update profile picture if enabled and different from current
            if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
                picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
                if picture_claim:
                    new_picture_url = user_data.get(
                        picture_claim,
                        OAUTH_PROVIDERS[provider].get("picture_url", ""),
                    )
                    processed_picture_url = await self._process_picture_url(
                        new_picture_url, token.get("access_token")
                    )
                    if processed_picture_url != user.profile_image_url:
                        Users.update_user_profile_image_url_by_id(
                            user.id, processed_picture_url
                        )
                        log.debug(f"Updated profile picture for user {user.email}")
        else:
            # If the user does not exist, check if signups are enabled
            if auth_manager_config.ENABLE_OAUTH_SIGNUP:
                # Check if an existing user with the same email already exists
                existing_user = Users.get_user_by_email(email)
                if existing_user:
                    raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

                picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
                if picture_claim:
                    picture_url = user_data.get(
                        picture_claim,
                        OAUTH_PROVIDERS[provider].get("picture_url", ""),
                    )
                    picture_url = await self._process_picture_url(
                        picture_url, token.get("access_token")
                    )
                else:
                    picture_url = "/user.png"

                username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
                name = user_data.get(username_claim)
                if not name:
                    log.warning("Username claim is missing, using email as name")
                    name = email

                user = Auths.insert_new_auth(
                    email=email,
                    password=get_password_hash(
                        str(uuid.uuid4())
                    ),  # Random password, not used
                    name=name,
                    profile_image_url=picture_url,
                    role=self.get_user_role(None, user_data),
                    oauth_sub=provider_sub,
                )

                if auth_manager_config.WEBHOOK_URL:
                    await post_webhook(
                        WEBUI_NAME,
                        auth_manager_config.WEBHOOK_URL,
                        WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                        {
                            "action": "signup",
                            "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                            "user": user.model_dump_json(exclude_none=True),
                        },
                    )
            else:
                raise HTTPException(
                    status.HTTP_403_FORBIDDEN,
                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
                )

        jwt_token = create_token(
            data={"id": user.id},
            expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
        )

        if (
            auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT
            and user.role != "admin"
        ):
            self.update_user_groups(
                user=user,
                user_data=user_data,
                default_permissions=request.app.state.config.USER_PERMISSIONS,
            )
    except Exception as e:
        # Any failure above is turned into an error redirect rather than an
        # HTTP error response.
        log.error(f"Error during OAuth process: {e}")
        error_message = (
            e.detail
            if isinstance(e, HTTPException) and e.detail
            else ERROR_MESSAGES.DEFAULT("Error during OAuth process")
        )

    redirect_base_url = (
        str(request.app.state.config.WEBUI_URL or request.base_url)
    ).rstrip("/")
    redirect_url = f"{redirect_base_url}/auth"

    if error_message:
        # Percent-encode the message: it contains spaces and may contain
        # "&", "#" or non-ASCII text that would otherwise corrupt the query.
        encoded_error = quote(str(error_message), safe="")
        redirect_url = f"{redirect_url}?error={encoded_error}"
        return RedirectResponse(url=redirect_url, headers=response.headers)

    response = RedirectResponse(url=redirect_url, headers=response.headers)

    # Set the cookie token
    # Redirect back to the frontend with the JWT token
    response.set_cookie(
        key="token",
        value=jwt_token,
        httponly=False,  # Required for frontend access
        samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
        secure=WEBUI_AUTH_COOKIE_SECURE,
    )

    # Legacy cookies for compatibility with older frontend versions
    if ENABLE_OAUTH_ID_TOKEN_COOKIE:
        response.set_cookie(
            key="oauth_id_token",
            value=token.get("id_token"),
            httponly=True,
            samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
            secure=WEBUI_AUTH_COOKIE_SECURE,
        )

    try:
        # Add timestamp for tracking; one clock sample keeps issued_at and the
        # derived expires_at consistent with each other.
        issued_at = datetime.now().timestamp()
        token["issued_at"] = issued_at

        # Calculate expires_at if we have expires_in
        if "expires_in" in token and "expires_at" not in token:
            token["expires_at"] = issued_at + token["expires_in"]

        # Clean up any existing sessions for this user/provider first
        for stale_session in OAuthSessions.get_sessions_by_user_id(user.id):
            if stale_session.provider == provider:
                OAuthSessions.delete_session_by_id(stale_session.id)

        oauth_session = OAuthSessions.create_session(
            user_id=user.id,
            provider=provider,
            token=token,
        )

        response.set_cookie(
            key="oauth_session_id",
            value=oauth_session.id,
            httponly=True,
            samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
            secure=WEBUI_AUTH_COOKIE_SECURE,
        )

        log.info(
            f"Stored OAuth session server-side for user {user.id}, provider {provider}"
        )
    except Exception as e:
        # Best effort: the login still succeeds without a stored session.
        log.error(f"Failed to store OAuth session server-side: {e}")

    return response
|