|
@@ -1,7 +1,9 @@
|
|
|
import base64
|
|
|
+import hashlib
|
|
|
import logging
|
|
|
import mimetypes
|
|
|
import sys
|
|
|
+import urllib
|
|
|
import uuid
|
|
|
import json
|
|
|
from datetime import datetime, timedelta
|
|
@@ -9,6 +11,9 @@ from datetime import datetime, timedelta
|
|
|
import re
|
|
|
import fnmatch
|
|
|
import time
|
|
|
+import secrets
|
|
|
+from cryptography.fernet import Fernet
|
|
|
+
|
|
|
|
|
|
import aiohttp
|
|
|
from authlib.integrations.starlette_client import OAuth
|
|
@@ -18,6 +23,7 @@ from fastapi import (
|
|
|
status,
|
|
|
)
|
|
|
from starlette.responses import RedirectResponse
|
|
|
+from typing import Optional
|
|
|
|
|
|
|
|
|
from open_webui.models.auths import Auths
|
|
@@ -56,11 +62,27 @@ from open_webui.env import (
|
|
|
WEBUI_AUTH_COOKIE_SAME_SITE,
|
|
|
WEBUI_AUTH_COOKIE_SECURE,
|
|
|
ENABLE_OAUTH_ID_TOKEN_COOKIE,
|
|
|
+ OAUTH_CLIENT_INFO_ENCRYPTION_KEY,
|
|
|
)
|
|
|
from open_webui.utils.misc import parse_duration
|
|
|
from open_webui.utils.auth import get_password_hash, create_token
|
|
|
from open_webui.utils.webhook import post_webhook
|
|
|
|
|
|
+from mcp.shared.auth import (
|
|
|
+ OAuthClientMetadata,
|
|
|
+ OAuthMetadata,
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
+class OAuthClientInformationFull(OAuthClientMetadata):
|
|
|
+ issuer: Optional[str] = None # URL of the OAuth server that issued this client
|
|
|
+
|
|
|
+ client_id: str
|
|
|
+ client_secret: str | None = None
|
|
|
+ client_id_issued_at: int | None = None
|
|
|
+ client_secret_expires_at: int | None = None
|
|
|
+
|
|
|
+
|
|
|
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
|
|
|
|
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
|
@@ -89,6 +111,42 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
|
|
auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN = OAUTH_UPDATE_PICTURE_ON_LOGIN
|
|
|
|
|
|
|
|
|
+FERNET = None
|
|
|
+
|
|
|
+if len(OAUTH_CLIENT_INFO_ENCRYPTION_KEY) != 44:
|
|
|
+ key_bytes = hashlib.sha256(OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()).digest()
|
|
|
+ OAUTH_CLIENT_INFO_ENCRYPTION_KEY = base64.urlsafe_b64encode(key_bytes)
|
|
|
+else:
|
|
|
+ OAUTH_CLIENT_INFO_ENCRYPTION_KEY = OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()
|
|
|
+
|
|
|
+try:
|
|
|
+ FERNET = Fernet(OAUTH_CLIENT_INFO_ENCRYPTION_KEY)
|
|
|
+except Exception as e:
|
|
|
+ log.error(f"Error initializing Fernet with provided key: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+
|
|
|
+def encrypt_token(token) -> str:
|
|
|
+ """Encrypt OAuth tokens for storage"""
|
|
|
+ try:
|
|
|
+ token_json = json.dumps(token)
|
|
|
+ encrypted = FERNET.encrypt(token_json.encode()).decode()
|
|
|
+ return encrypted
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error encrypting tokens: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+
|
|
|
+def decrypt_token(token: str):
|
|
|
+ """Decrypt OAuth tokens from storage"""
|
|
|
+ try:
|
|
|
+ decrypted = FERNET.decrypt(token.encode()).decode()
|
|
|
+ return json.loads(decrypted)
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error decrypting tokens: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+
|
|
|
def is_in_blocked_groups(group_name: str, groups: list) -> bool:
|
|
|
"""
|
|
|
Check if a group name matches any blocked pattern.
|
|
@@ -133,6 +191,406 @@ def is_in_blocked_groups(group_name: str, groups: list) -> bool:
|
|
|
return False
|
|
|
|
|
|
|
|
|
+def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]:
|
|
|
+ parsed = urllib.parse.urlparse(server_url)
|
|
|
+ base_url = f"{parsed.scheme}://{parsed.netloc}"
|
|
|
+ return parsed, base_url
|
|
|
+
|
|
|
+
|
|
|
+def get_discovery_urls(server_url) -> list[str]:
|
|
|
+ urls = []
|
|
|
+ parsed, base_url = get_parsed_and_base_url(server_url)
|
|
|
+
|
|
|
+ urls.append(
|
|
|
+ urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server")
|
|
|
+ )
|
|
|
+ urls.append(urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"))
|
|
|
+
|
|
|
+ return urls
|
|
|
+
|
|
|
+
|
|
|
+# TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration.
|
|
|
+# This is not currently supported.
|
|
|
+async def get_oauth_client_info_with_dynamic_client_registration(
|
|
|
+ request, oauth_server_url, oauth_server_key: Optional[str] = None
|
|
|
+) -> OAuthClientInformationFull:
|
|
|
+ try:
|
|
|
+ oauth_server_metadata = None
|
|
|
+ oauth_server_metadata_url = None
|
|
|
+
|
|
|
+ redirect_base_url = (
|
|
|
+ str(request.app.state.config.WEBUI_URL or request.base_url)
|
|
|
+ ).rstrip("/")
|
|
|
+ oauth_client_metadata = OAuthClientMetadata(
|
|
|
+ client_name="Open WebUI",
|
|
|
+ redirect_uris=[f"{redirect_base_url}/oauth/callback"],
|
|
|
+ grant_types=["authorization_code", "refresh_token"],
|
|
|
+ response_types=["code"],
|
|
|
+ token_endpoint_auth_method="client_secret_post",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Attempt to fetch OAuth server metadata to get registration endpoint & scopes
|
|
|
+ discovery_urls = get_discovery_urls(oauth_server_url)
|
|
|
+ for url in discovery_urls:
|
|
|
+ async with aiohttp.ClientSession() as session:
|
|
|
+ async with session.get(
|
|
|
+ url, ssl=AIOHTTP_CLIENT_SESSION_SSL
|
|
|
+ ) as oauth_server_metadata_response:
|
|
|
+ if oauth_server_metadata_response.status == 200:
|
|
|
+ try:
|
|
|
+ oauth_server_metadata = OAuthMetadata.model_validate(
|
|
|
+ await oauth_server_metadata_response.json()
|
|
|
+ )
|
|
|
+ oauth_server_metadata_url = url
|
|
|
+ if (
|
|
|
+ oauth_client_metadata.scope is None
|
|
|
+ and oauth_server_metadata.scopes_supported is not None
|
|
|
+ ):
|
|
|
+ oauth_client_metadata.scope = " ".join(
|
|
|
+ oauth_server_metadata.scopes_supported
|
|
|
+ )
|
|
|
+ break
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error parsing OAuth metadata from {url}: {e}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ registration_url = None
|
|
|
+ if oauth_server_metadata and oauth_server_metadata.registration_endpoint:
|
|
|
+ registration_url = str(oauth_server_metadata.registration_endpoint)
|
|
|
+ else:
|
|
|
+ _, base_url = get_parsed_and_base_url(oauth_server_url)
|
|
|
+ registration_url = urllib.parse.urljoin(base_url, "/register")
|
|
|
+
|
|
|
+ registration_data = oauth_client_metadata.model_dump(
|
|
|
+ exclude_none=True,
|
|
|
+ mode="json",
|
|
|
+ by_alias=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ # Perform dynamic client registration and return client info
|
|
|
+ async with aiohttp.ClientSession() as session:
|
|
|
+ async with session.post(
|
|
|
+ registration_url, json=registration_data, ssl=AIOHTTP_CLIENT_SESSION_SSL
|
|
|
+ ) as oauth_client_registration_response:
|
|
|
+ try:
|
|
|
+ registration_response_json = (
|
|
|
+ await oauth_client_registration_response.json()
|
|
|
+ )
|
|
|
+ oauth_client_info = OAuthClientInformationFull.model_validate(
|
|
|
+ {
|
|
|
+ **registration_response_json,
|
|
|
+ **{"issuer": oauth_server_metadata_url},
|
|
|
+ }
|
|
|
+ )
|
|
|
+ log.info(
|
|
|
+ f"Dynamic client registration successful at {registration_url}, client_id: {oauth_client_info.client_id}"
|
|
|
+ )
|
|
|
+ return oauth_client_info
|
|
|
+ except Exception as e:
|
|
|
+ error_text = None
|
|
|
+ try:
|
|
|
+ error_text = await oauth_client_registration_response.text()
|
|
|
+ log.error(
|
|
|
+ f"Dynamic client registration failed at {registration_url}: {oauth_client_registration_response.status} - {error_text}"
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ pass
|
|
|
+
|
|
|
+ log.error(f"Error parsing client registration response: {e}")
|
|
|
+ raise Exception(
|
|
|
+ f"Dynamic client registration failed: {error_text}"
|
|
|
+ if error_text
|
|
|
+ else "Error parsing client registration response"
|
|
|
+ )
|
|
|
+ raise Exception("Dynamic client registration failed")
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Exception during dynamic client registration: {e}")
|
|
|
+ raise e
|
|
|
+
|
|
|
+
|
|
|
+class OAuthClientManager:
|
|
|
+ def __init__(self, app):
|
|
|
+ self.oauth = OAuth()
|
|
|
+ self.app = app
|
|
|
+ self.clients = {}
|
|
|
+
|
|
|
+ def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull):
|
|
|
+ if client_id not in self.clients:
|
|
|
+ self.clients[client_id] = {
|
|
|
+ "client": self.oauth.register(
|
|
|
+ name=client_id,
|
|
|
+ client_id=oauth_client_info.client_id,
|
|
|
+ client_secret=oauth_client_info.client_secret,
|
|
|
+ client_kwargs=(
|
|
|
+ {"scope": oauth_client_info.scope}
|
|
|
+ if oauth_client_info.scope
|
|
|
+ else {}
|
|
|
+ ),
|
|
|
+ server_metadata_url=(
|
|
|
+ oauth_client_info.issuer if oauth_client_info.issuer else None
|
|
|
+ ),
|
|
|
+ ),
|
|
|
+ "client_info": oauth_client_info,
|
|
|
+ }
|
|
|
+ return self.clients[client_id]
|
|
|
+
|
|
|
+ def remove_client(self, client_id):
|
|
|
+ if client_id in self.clients:
|
|
|
+ del self.clients[client_id]
|
|
|
+ log.info(f"Removed OAuth client {client_id}")
|
|
|
+ return True
|
|
|
+
|
|
|
+ def get_client(self, client_id):
|
|
|
+ client = self.clients.get(client_id)
|
|
|
+ return client["client"] if client else None
|
|
|
+
|
|
|
+ def get_client_info(self, client_id):
|
|
|
+ client = self.clients.get(client_id)
|
|
|
+ return client["client_info"] if client else None
|
|
|
+
|
|
|
+ def get_server_metadata_url(self, client_id):
|
|
|
+ if client_id in self.clients:
|
|
|
+ client = self.clients[client_id]
|
|
|
+ return (
|
|
|
+ client.server_metadata_url
|
|
|
+ if hasattr(client, "server_metadata_url")
|
|
|
+ else None
|
|
|
+ )
|
|
|
+ return None
|
|
|
+
|
|
|
+ async def get_oauth_token(
|
|
|
+ self, user_id: str, session_id: str, force_refresh: bool = False
|
|
|
+ ):
|
|
|
+ """
|
|
|
+ Get a valid OAuth token for the user, automatically refreshing if needed.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ user_id: The user ID
|
|
|
+ session_id: The OAuth session ID
|
|
|
+ force_refresh: Force token refresh even if current token appears valid
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: OAuth token data with access_token, or None if no valid token available
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # Get the OAuth session
|
|
|
+ session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
|
|
|
+ if not session:
|
|
|
+ log.warning(
|
|
|
+ f"No OAuth session found for user {user_id}, session {session_id}"
|
|
|
+ )
|
|
|
+ return None
|
|
|
+
|
|
|
+ if force_refresh or datetime.now() + timedelta(
|
|
|
+ minutes=5
|
|
|
+ ) >= datetime.fromtimestamp(session.expires_at):
|
|
|
+ log.debug(
|
|
|
+ f"Token refresh needed for user {user_id}, client_id {session.provider}"
|
|
|
+ )
|
|
|
+ refreshed_token = await self._refresh_token(session)
|
|
|
+ if refreshed_token:
|
|
|
+ return refreshed_token
|
|
|
+ else:
|
|
|
+ log.warning(
|
|
|
+ f"Token refresh failed for user {user_id}, client_id {session.provider}"
|
|
|
+ )
|
|
|
+ return None
|
|
|
+ return session.token
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error getting OAuth token for user {user_id}: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ async def _refresh_token(self, session) -> dict:
|
|
|
+ """
|
|
|
+ Refresh an OAuth token if needed, with concurrency protection.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ session: The OAuth session object
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: Refreshed token data, or None if refresh failed
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # Perform the actual refresh
|
|
|
+ refreshed_token = await self._perform_token_refresh(session)
|
|
|
+
|
|
|
+ if refreshed_token:
|
|
|
+ # Update the session with new token data
|
|
|
+ session = OAuthSessions.update_session_by_id(
|
|
|
+ session.id, refreshed_token
|
|
|
+ )
|
|
|
+ log.info(f"Successfully refreshed token for session {session.id}")
|
|
|
+ return session.token
|
|
|
+ else:
|
|
|
+ log.error(f"Failed to refresh token for session {session.id}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error refreshing token for session {session.id}: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ async def _perform_token_refresh(self, session) -> dict:
|
|
|
+ """
|
|
|
+ Perform the actual OAuth token refresh.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ session: The OAuth session object
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: New token data, or None if refresh failed
|
|
|
+ """
|
|
|
+ client_id = session.provider
|
|
|
+ token_data = session.token
|
|
|
+
|
|
|
+ if not token_data.get("refresh_token"):
|
|
|
+ log.warning(f"No refresh token available for session {session.id}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ try:
|
|
|
+ client = self.get_client(client_id)
|
|
|
+ if not client:
|
|
|
+ log.error(f"No OAuth client found for provider {client_id}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ token_endpoint = None
|
|
|
+ async with aiohttp.ClientSession(trust_env=True) as session_http:
|
|
|
+ async with session_http.get(
|
|
|
+ self.get_server_metadata_url(client_id)
|
|
|
+ ) as r:
|
|
|
+ if r.status == 200:
|
|
|
+ openid_data = await r.json()
|
|
|
+ token_endpoint = openid_data.get("token_endpoint")
|
|
|
+ else:
|
|
|
+ log.error(
|
|
|
+ f"Failed to fetch OpenID configuration for client_id {client_id}"
|
|
|
+ )
|
|
|
+ if not token_endpoint:
|
|
|
+ log.error(f"No token endpoint found for client_id {client_id}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ # Prepare refresh request
|
|
|
+ refresh_data = {
|
|
|
+ "grant_type": "refresh_token",
|
|
|
+ "refresh_token": token_data["refresh_token"],
|
|
|
+ "client_id": client.client_id,
|
|
|
+ }
|
|
|
+ if hasattr(client, "client_secret") and client.client_secret:
|
|
|
+ refresh_data["client_secret"] = client.client_secret
|
|
|
+
|
|
|
+ # Make refresh request
|
|
|
+ async with aiohttp.ClientSession(trust_env=True) as session_http:
|
|
|
+ async with session_http.post(
|
|
|
+ token_endpoint,
|
|
|
+ data=refresh_data,
|
|
|
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
|
+ ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
|
|
+ ) as r:
|
|
|
+ if r.status == 200:
|
|
|
+ new_token_data = await r.json()
|
|
|
+
|
|
|
+ # Merge with existing token data (preserve refresh_token if not provided)
|
|
|
+ if "refresh_token" not in new_token_data:
|
|
|
+ new_token_data["refresh_token"] = token_data[
|
|
|
+ "refresh_token"
|
|
|
+ ]
|
|
|
+
|
|
|
+ # Add timestamp for tracking
|
|
|
+ new_token_data["issued_at"] = datetime.now().timestamp()
|
|
|
+
|
|
|
+ # Calculate expires_at if we have expires_in
|
|
|
+ if (
|
|
|
+ "expires_in" in new_token_data
|
|
|
+ and "expires_at" not in new_token_data
|
|
|
+ ):
|
|
|
+ new_token_data["expires_at"] = int(
|
|
|
+ datetime.now().timestamp()
|
|
|
+ + new_token_data["expires_in"]
|
|
|
+ )
|
|
|
+
|
|
|
+ log.debug(f"Token refresh successful for client_id {client_id}")
|
|
|
+ return new_token_data
|
|
|
+ else:
|
|
|
+ error_text = await r.text()
|
|
|
+ log.error(
|
|
|
+ f"Token refresh failed for client_id {client_id}: {r.status} - {error_text}"
|
|
|
+ )
|
|
|
+ return None
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Exception during token refresh for client_id {client_id}: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
|
|
|
+ client = self.get_client(client_id)
|
|
|
+ if client is None:
|
|
|
+ raise HTTPException(404)
|
|
|
+
|
|
|
+ client_info = self.get_client_info(client_id)
|
|
|
+ if client_info is None:
|
|
|
+ raise HTTPException(404)
|
|
|
+
|
|
|
+ redirect_uri = (
|
|
|
+ client_info.redirect_uris[0] if client_info.redirect_uris else None
|
|
|
+ )
|
|
|
+ return await client.authorize_redirect(request, redirect_uri)
|
|
|
+
|
|
|
+ async def handle_callback(self, request, client_id: str, user_id: str, response):
|
|
|
+ client = self.get_client(client_id)
|
|
|
+ if client is None:
|
|
|
+ raise HTTPException(404)
|
|
|
+
|
|
|
+ error_message = None
|
|
|
+ try:
|
|
|
+ token = await client.authorize_access_token(request)
|
|
|
+ if token:
|
|
|
+ try:
|
|
|
+ # Add timestamp for tracking
|
|
|
+ token["issued_at"] = datetime.now().timestamp()
|
|
|
+
|
|
|
+ # Calculate expires_at if we have expires_in
|
|
|
+ if "expires_in" in token and "expires_at" not in token:
|
|
|
+ token["expires_at"] = (
|
|
|
+ datetime.now().timestamp() + token["expires_in"]
|
|
|
+ )
|
|
|
+
|
|
|
+ # Clean up any existing sessions for this user/client_id first
|
|
|
+ sessions = OAuthSessions.get_sessions_by_user_id(user_id)
|
|
|
+ for session in sessions:
|
|
|
+ if session.provider == client_id:
|
|
|
+ OAuthSessions.delete_session_by_id(session.id)
|
|
|
+
|
|
|
+ session = OAuthSessions.create_session(
|
|
|
+ user_id=user_id,
|
|
|
+ provider=client_id,
|
|
|
+ token=token,
|
|
|
+ )
|
|
|
+
|
|
|
+ log.info(
|
|
|
+ f"Stored OAuth session server-side for user {user_id}, client_id {client_id}"
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ error_message = "Failed to store OAuth session server-side"
|
|
|
+ log.error(f"Failed to store OAuth session server-side: {e}")
|
|
|
+ else:
|
|
|
+ error_message = "Failed to obtain OAuth token"
|
|
|
+ log.warning(error_message)
|
|
|
+ except Exception as e:
|
|
|
+ error_message = "OAuth callback error"
|
|
|
+ log.warning(f"OAuth callback error: {e}")
|
|
|
+
|
|
|
+ redirect_base_url = (
|
|
|
+ str(request.app.state.config.WEBUI_URL or request.base_url)
|
|
|
+ ).rstrip("/")
|
|
|
+ redirect_url = f"{redirect_base_url}/auth"
|
|
|
+
|
|
|
+ if error_message:
|
|
|
+ redirect_url = f"{redirect_url}?error={error_message}"
|
|
|
+ return RedirectResponse(url=redirect_url, headers=response.headers)
|
|
|
+
|
|
|
+ response = RedirectResponse(url=redirect_url, headers=response.headers)
|
|
|
+
|
|
|
+
|
|
|
class OAuthManager:
|
|
|
def __init__(self, app):
|
|
|
self.oauth = OAuth()
|
|
@@ -792,9 +1250,9 @@ class OAuthManager:
|
|
|
else ERROR_MESSAGES.DEFAULT("Error during OAuth process")
|
|
|
)
|
|
|
|
|
|
- redirect_base_url = str(request.app.state.config.WEBUI_URL or request.base_url)
|
|
|
- if redirect_base_url.endswith("/"):
|
|
|
- redirect_base_url = redirect_base_url[:-1]
|
|
|
+ redirect_base_url = (
|
|
|
+ str(request.app.state.config.WEBUI_URL or request.base_url)
|
|
|
+ ).rstrip("/")
|
|
|
redirect_url = f"{redirect_base_url}/auth"
|
|
|
|
|
|
if error_message:
|