|
@@ -4,9 +4,11 @@ import mimetypes
|
|
|
import sys
|
|
|
import uuid
|
|
|
import json
|
|
|
+from datetime import datetime, timedelta
|
|
|
|
|
|
import re
|
|
|
import fnmatch
|
|
|
+import time
|
|
|
|
|
|
import aiohttp
|
|
|
from authlib.integrations.starlette_client import OAuth
|
|
@@ -17,8 +19,12 @@ from fastapi import (
|
|
|
)
|
|
|
from starlette.responses import RedirectResponse
|
|
|
|
|
|
+
|
|
|
from open_webui.models.auths import Auths
|
|
|
+from open_webui.models.oauth_sessions import OAuthSessions
|
|
|
from open_webui.models.users import Users
|
|
|
+
|
|
|
+
|
|
|
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
|
|
|
from open_webui.config import (
|
|
|
DEFAULT_USER_ROLE,
|
|
@@ -49,7 +55,7 @@ from open_webui.env import (
|
|
|
WEBUI_NAME,
|
|
|
WEBUI_AUTH_COOKIE_SAME_SITE,
|
|
|
WEBUI_AUTH_COOKIE_SECURE,
|
|
|
- ENABLE_OAUTH_SESSION_TOKENS_COOKIES,
|
|
|
+ ENABLE_OAUTH_ID_TOKEN_COOKIE,
|
|
|
)
|
|
|
from open_webui.utils.misc import parse_duration
|
|
|
from open_webui.utils.auth import get_password_hash, create_token
|
|
@@ -131,11 +137,187 @@ class OAuthManager:
|
|
|
def __init__(self, app):
|
|
|
self.oauth = OAuth()
|
|
|
self.app = app
|
|
|
+
|
|
|
+ self._clients = {}
|
|
|
for _, provider_config in OAUTH_PROVIDERS.items():
|
|
|
provider_config["register"](self.oauth)
|
|
|
|
|
|
def get_client(self, provider_name):
|
|
|
- return self.oauth.create_client(provider_name)
|
|
|
+ if provider_name not in self._clients:
|
|
|
+ self._clients[provider_name] = self.oauth.create_client(provider_name)
|
|
|
+ return self._clients[provider_name]
|
|
|
+
|
|
|
+ def get_server_metadata_url(self, provider_name):
|
|
|
+ if provider_name in self._clients:
|
|
|
+ client = self._clients[provider_name]
|
|
|
+ return (
|
|
|
+ client.server_metadata_url
|
|
|
+ if hasattr(client, "server_metadata_url")
|
|
|
+ else None
|
|
|
+ )
|
|
|
+ return None
|
|
|
+
|
|
|
+ def get_oauth_token(
|
|
|
+ self, user_id: str, session_id: str, force_refresh: bool = False
|
|
|
+ ):
|
|
|
+ """
|
|
|
+ Get a valid OAuth token for the user, automatically refreshing if needed.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ user_id: The user ID
|
|
|
+ provider: Optional provider name. If None, gets the most recent session.
|
|
|
+ force_refresh: Force token refresh even if current token appears valid
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: OAuth token data with access_token, or None if no valid token available
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # Get the OAuth session
|
|
|
+ session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
|
|
|
+ if not session:
|
|
|
+ log.warning(
|
|
|
+ f"No OAuth session found for user {user_id}, session {session_id}"
|
|
|
+ )
|
|
|
+ return None
|
|
|
+
|
|
|
+ if force_refresh or datetime.now() + timedelta(
|
|
|
+ minutes=5
|
|
|
+ ) >= datetime.fromtimestamp(session.expires_at):
|
|
|
+ log.debug(
|
|
|
+ f"Token refresh needed for user {user_id}, provider {session.provider}"
|
|
|
+ )
|
|
|
+ refreshed_token = self._refresh_token(session)
|
|
|
+ if refreshed_token:
|
|
|
+ return refreshed_token
|
|
|
+ else:
|
|
|
+ log.warning(
|
|
|
+ f"Token refresh failed for user {user_id}, provider {session.provider}"
|
|
|
+ )
|
|
|
+ return None
|
|
|
+ return session.token
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error getting OAuth token for user {user_id}: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ async def _refresh_token(self, session) -> dict:
|
|
|
+ """
|
|
|
+ Refresh an OAuth token if needed, with concurrency protection.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ session: The OAuth session object
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: Refreshed token data, or None if refresh failed
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # Perform the actual refresh
|
|
|
+ refreshed_token = await self._perform_token_refresh(session)
|
|
|
+
|
|
|
+ if refreshed_token:
|
|
|
+ # Update the session with new token data
|
|
|
+ session = OAuthSessions.update_session_by_id(
|
|
|
+ session.id, refreshed_token
|
|
|
+ )
|
|
|
+ log.info(f"Successfully refreshed token for session {session.id}")
|
|
|
+ return session.token
|
|
|
+ else:
|
|
|
+ log.error(f"Failed to refresh token for session {session.id}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error refreshing token for session {session.id}: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ async def _perform_token_refresh(self, session) -> dict:
|
|
|
+ """
|
|
|
+ Perform the actual OAuth token refresh.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ session: The OAuth session object
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: New token data, or None if refresh failed
|
|
|
+ """
|
|
|
+ provider = session.provider
|
|
|
+ token_data = session.token
|
|
|
+
|
|
|
+ if not token_data.get("refresh_token"):
|
|
|
+ log.warning(f"No refresh token available for session {session.id}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ try:
|
|
|
+ client = self.get_client(provider)
|
|
|
+ if not client:
|
|
|
+ log.error(f"No OAuth client found for provider {provider}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ token_endpoint = None
|
|
|
+ async with aiohttp.ClientSession(trust_env=True) as session_http:
|
|
|
+ async with session_http.get(client.gserver_metadata_url) as r:
|
|
|
+ if r.status == 200:
|
|
|
+ openid_data = await r.json()
|
|
|
+ token_endpoint = openid_data.get("token_endpoint")
|
|
|
+ else:
|
|
|
+ log.error(
|
|
|
+ f"Failed to fetch OpenID configuration for provider {provider}"
|
|
|
+ )
|
|
|
+ if not token_endpoint:
|
|
|
+ log.error(f"No token endpoint found for provider {provider}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ # Prepare refresh request
|
|
|
+ refresh_data = {
|
|
|
+ "grant_type": "refresh_token",
|
|
|
+ "refresh_token": token_data["refresh_token"],
|
|
|
+ "client_id": client.client_id,
|
|
|
+ }
|
|
|
+ # Add client_secret if available (some providers require it)
|
|
|
+ if hasattr(client, "client_secret") and client.client_secret:
|
|
|
+ refresh_data["client_secret"] = client.client_secret
|
|
|
+
|
|
|
+ # Make refresh request
|
|
|
+ async with aiohttp.ClientSession(trust_env=True) as session_http:
|
|
|
+ async with session_http.post(
|
|
|
+ token_endpoint,
|
|
|
+ data=refresh_data,
|
|
|
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
|
+ ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
|
|
+ ) as r:
|
|
|
+ if r.status == 200:
|
|
|
+ new_token_data = await r.json()
|
|
|
+
|
|
|
+ # Merge with existing token data (preserve refresh_token if not provided)
|
|
|
+ if "refresh_token" not in new_token_data:
|
|
|
+ new_token_data["refresh_token"] = token_data[
|
|
|
+ "refresh_token"
|
|
|
+ ]
|
|
|
+
|
|
|
+ # Add timestamp for tracking
|
|
|
+ new_token_data["issued_at"] = datetime.now().timestamp()
|
|
|
+
|
|
|
+ # Calculate expires_at if we have expires_in
|
|
|
+ if (
|
|
|
+ "expires_in" in new_token_data
|
|
|
+ and "expires_at" not in new_token_data
|
|
|
+ ):
|
|
|
+ new_token_data["expires_at"] = (
|
|
|
+ datetime.now().timestamp()
|
|
|
+ + new_token_data["expires_in"]
|
|
|
+ )
|
|
|
+
|
|
|
+ log.debug(f"Token refresh successful for provider {provider}")
|
|
|
+ return new_token_data
|
|
|
+ else:
|
|
|
+ error_text = await r.text()
|
|
|
+ log.error(
|
|
|
+ f"Token refresh failed for provider {provider}: {r.status} - {error_text}"
|
|
|
+ )
|
|
|
+ return None
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Exception during token refresh for provider {provider}: {e}")
|
|
|
+ return None
|
|
|
|
|
|
def get_user_role(self, user, user_data):
|
|
|
user_count = Users.get_num_users()
|
|
@@ -624,33 +806,42 @@ class OAuthManager:
|
|
|
secure=WEBUI_AUTH_COOKIE_SECURE,
|
|
|
)
|
|
|
|
|
|
- if ENABLE_OAUTH_SIGNUP.value:
|
|
|
- if ENABLE_OAUTH_SESSION_TOKENS_COOKIES:
|
|
|
- oauth_id_token = token.get("id_token")
|
|
|
- response.set_cookie(
|
|
|
- key="oauth_id_token",
|
|
|
- value=oauth_id_token,
|
|
|
- httponly=True,
|
|
|
- samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
|
|
- secure=WEBUI_AUTH_COOKIE_SECURE,
|
|
|
- )
|
|
|
+ # Legacy cookies for compatibility with older frontend versions
|
|
|
+ if ENABLE_OAUTH_ID_TOKEN_COOKIE:
|
|
|
+ response.set_cookie(
|
|
|
+ key="oauth_id_token",
|
|
|
+ value=token.get("id_token"),
|
|
|
+ httponly=True,
|
|
|
+ samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
|
|
+ secure=WEBUI_AUTH_COOKIE_SECURE,
|
|
|
+ )
|
|
|
|
|
|
- oauth_access_token = token.get("access_token")
|
|
|
- response.set_cookie(
|
|
|
- key="oauth_access_token",
|
|
|
- value=oauth_access_token,
|
|
|
- httponly=True,
|
|
|
- samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
|
|
- secure=WEBUI_AUTH_COOKIE_SECURE,
|
|
|
- )
|
|
|
+ try:
|
|
|
+ # Add timestamp for tracking
|
|
|
+ token["issued_at"] = datetime.now().timestamp()
|
|
|
|
|
|
- oauth_refresh_token = token.get("refresh_token")
|
|
|
- response.set_cookie(
|
|
|
- key="oauth_refresh_token",
|
|
|
- value=oauth_refresh_token,
|
|
|
- httponly=True,
|
|
|
- samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
|
|
- secure=WEBUI_AUTH_COOKIE_SECURE,
|
|
|
- )
|
|
|
+ # Calculate expires_at if we have expires_in
|
|
|
+ if "expires_in" in token and "expires_at" not in token:
|
|
|
+ token["expires_at"] = datetime.now().timestamp() + token["expires_in"]
|
|
|
+
|
|
|
+ session_id = await OAuthSessions.create_session(
|
|
|
+ user_id=user.id,
|
|
|
+ provider=provider,
|
|
|
+ token=token,
|
|
|
+ )
|
|
|
+
|
|
|
+ response.set_cookie(
|
|
|
+ key="oauth_session_id",
|
|
|
+ value=session_id,
|
|
|
+ httponly=True,
|
|
|
+ samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
|
|
|
+ secure=WEBUI_AUTH_COOKIE_SECURE,
|
|
|
+ )
|
|
|
+
|
|
|
+ log.info(
|
|
|
+ f"Stored OAuth session server-side for user {user.id}, provider {provider}"
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Failed to store OAuth session server-side: {e}")
|
|
|
|
|
|
return response
|