浏览代码

feat: server-side OAuth token management system

Co-Authored-By: Classic298 <27028174+Classic298@users.noreply.github.com>
Timothy Jaeryang Baek 1 月之前
父节点
当前提交
217f4daef0

+ 11 - 2
backend/open_webui/env.py

@@ -465,8 +465,17 @@ ENABLE_COMPRESSION_MIDDLEWARE = (
     os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
 )
 
-ENABLE_OAUTH_SESSION_TOKENS_COOKIES = (
-    os.environ.get("ENABLE_OAUTH_SESSION_TOKENS_COOKIES", "True").lower() == "true"
+####################################
+# OAUTH Configuration
+####################################
+
+
+ENABLE_OAUTH_ID_TOKEN_COOKIE = (
+    os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true"
+)
+
+OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get(
+    "OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY
 )
 
 

+ 1 - 0
backend/open_webui/main.py

@@ -592,6 +592,7 @@ app = FastAPI(
 )
 
 oauth_manager = OAuthManager(app)
+app.state.oauth_manager = oauth_manager
 
 app.state.instance_id = None
 app.state.config = AppConfig(

+ 52 - 0
backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py

@@ -0,0 +1,52 @@
+"""Add oauth_session table
+
+Revision ID: 38d63c18f30f
+Revises: 3af16a1c9fb6
+Create Date: 2025-09-08 14:19:59.583921
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "38d63c18f30f"
+down_revision: Union[str, None] = "3af16a1c9fb6"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # Create oauth_session table
+    op.create_table(
+        "oauth_session",
+        sa.Column("id", sa.Text(), nullable=False),
+        sa.Column("user_id", sa.Text(), nullable=False),
+        sa.Column("provider", sa.Text(), nullable=False),
+        sa.Column("token", sa.Text(), nullable=False),
+        sa.Column("expires_at", sa.BigInteger(), nullable=False),
+        sa.Column("created_at", sa.BigInteger(), nullable=False),
+        sa.Column("updated_at", sa.BigInteger(), nullable=False),
+        sa.PrimaryKeyConstraint("id"),
+        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
+    )
+
+    # Create indexes for better performance
+    op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"])
+    op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"])
+    op.create_index(
+        "idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"]
+    )
+
+
+def downgrade() -> None:
+    # Drop indexes first
+    op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session")
+    op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session")
+    op.drop_index("idx_oauth_session_user_id", table_name="oauth_session")
+
+    # Drop the table
+    op.drop_table("oauth_session")

+ 247 - 0
backend/open_webui/models/oauth_sessions.py

@@ -0,0 +1,247 @@
+import time
+import logging
+import uuid
+from typing import Optional, List
+import base64
+import hashlib
+import json
+
+from cryptography.fernet import Fernet
+
+from open_webui.internal.db import Base, get_db
+from open_webui.env import SRC_LOG_LEVELS, OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text, Index
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+####################
+# DB MODEL
+####################
+
+
+class OAuthSession(Base):
+    __tablename__ = "oauth_session"
+
+    id = Column(Text, primary_key=True)
+    user_id = Column(Text, nullable=False)
+    provider = Column(Text, nullable=False)
+    token = Column(
+        Text, nullable=False
+    )  # JSON with access_token, id_token, refresh_token
+    expires_at = Column(BigInteger, nullable=False)
+    created_at = Column(BigInteger, nullable=False)
+    updated_at = Column(BigInteger, nullable=False)
+
+    # Add indexes for better performance
+    __table_args__ = (
+        Index("idx_oauth_session_user_id", "user_id"),
+        Index("idx_oauth_session_expires_at", "expires_at"),
+        Index("idx_oauth_session_user_provider", "user_id", "provider"),
+    )
+
+
+class OAuthSessionModel(BaseModel):
+    id: str
+    user_id: str
+    provider: str
+    token: dict
+    expires_at: int  # timestamp in epoch
+    created_at: int  # timestamp in epoch
+    updated_at: int  # timestamp in epoch
+
+    model_config = ConfigDict(from_attributes=True)
+
+
+####################
+# Forms
+####################
+
+
+class OAuthSessionResponse(BaseModel):
+    id: str
+    user_id: str
+    provider: str
+    expires_at: int
+
+
+class OAuthSessionTable:
+    def __init__(self):
+        self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
+        if not self.encryption_key:
+            raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")
+
+        # check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
+        if len(self.encryption_key) != 44:
+            key_bytes = hashlib.sha256(self.encryption_key.encode()).digest()
+            self.encryption_key = base64.urlsafe_b64encode(key_bytes)
+        else:
+            self.encryption_key = self.encryption_key.encode()
+
+        try:
+            self.fernet = Fernet(self.encryption_key)
+        except Exception as e:
+            log.error(f"Error initializing Fernet with provided key: {e}")
+            raise
+
+    def _encrypt_token(self, token) -> str:
+        """Encrypt OAuth tokens for storage"""
+        try:
+            token_json = json.dumps(token)
+            encrypted = self.fernet.encrypt(token_json.encode()).decode()
+            return encrypted
+        except Exception as e:
+            log.error(f"Error encrypting tokens: {e}")
+            raise
+
+    def _decrypt_token(self, token: str):
+        """Decrypt OAuth tokens from storage"""
+        try:
+            decrypted = self.fernet.decrypt(token.encode()).decode()
+            return json.loads(decrypted)
+        except Exception as e:
+            log.error(f"Error decrypting tokens: {e}")
+            raise
+
+    def create_session(
+        self,
+        user_id: str,
+        provider: str,
+        token: dict,
+    ) -> Optional[OAuthSessionModel]:
+        """Create a new OAuth session"""
+        try:
+            with get_db() as db:
+                current_time = int(time.time())
+                id = str(uuid.uuid4())
+
+                result = OAuthSession(
+                    **{
+                        "id": id,
+                        "user_id": user_id,
+                        "provider": provider,
+                        "token": self._encrypt_token(token),
+                        "expires_at": token.get("expires_at"),
+                        "created_at": current_time,
+                        "updated_at": current_time,
+                    }
+                )
+
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+
+                if result:
+                    result.token = token  # Return decrypted token
+                    return OAuthSessionModel.model_validate(result)
+                else:
+                    return None
+        except Exception as e:
+            log.error(f"Error creating OAuth session: {e}")
+            return None
+
+    def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]:
+        """Get OAuth session by ID"""
+        try:
+            with get_db() as db:
+                session = db.query(OAuthSession).filter_by(id=session_id).first()
+                if session:
+                    session.token = self._decrypt_token(session.token)
+                    return OAuthSessionModel.model_validate(session)
+                    
+                return None
+        except Exception as e:
+            log.error(f"Error getting OAuth session by ID: {e}")
+            return None
+
+    def get_session_by_id_and_user_id(
+        self, session_id: str, user_id: str
+    ) -> Optional[OAuthSessionModel]:
+        """Get OAuth session by ID and user ID"""
+        try:
+            with get_db() as db:
+                session = (
+                    db.query(OAuthSession)
+                    .filter_by(id=session_id, user_id=user_id)
+                    .first()
+                )
+                if session:
+                    session.token = self._decrypt_token(session.token)
+                    return OAuthSessionModel.model_validate(session)
+                    )
+                return None
+        except Exception as e:
+            log.error(f"Error getting OAuth session by ID: {e}")
+            return None
+
+    def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
+        """Get all OAuth sessions for a user"""
+        try:
+            with get_db() as db:
+                sessions = db.query(OAuthSession).filter_by(user_id=user_id).all()
+
+
+                results = []
+                for session in sessions:
+                    session.token = self._decrypt_token(session.token)
+                    results.append(OAuthSessionModel.model_validate(session))
+
+                return results
+            
+        except Exception as e:
+            log.error(f"Error getting OAuth sessions by user ID: {e}")
+            return []
+
+    def update_session_by_id(
+        self, session_id: str, token: dict
+    ) -> Optional[OAuthSessionModel]:
+        """Update OAuth session tokens"""
+        try:
+            with get_db() as db:
+                current_time = int(time.time())
+
+                db.query(OAuthSession).filter_by(id=session_id).update(
+                    {
+                        "token": self._encrypt_token(token),
+                        "expires_at": token.get("expires_at"),
+                        "updated_at": current_time,
+                    }
+                )
+                db.commit()
+                session = db.query(OAuthSession).filter_by(id=session_id).first()
+
+                if session:
+                    session.token = self._decrypt_token(session.token)
+                    return OAuthSessionModel.model_validate(session)
+                    
+                return None
+        except Exception as e:
+            log.error(f"Error updating OAuth session tokens: {e}")
+            return None
+
+    def delete_session_by_id(self, session_id: str) -> bool:
+        """Delete an OAuth session"""
+        try:
+            with get_db() as db:
+                result = db.query(OAuthSession).filter_by(id=session_id).delete()
+                db.commit()
+                return result > 0
+        except Exception as e:
+            log.error(f"Error deleting OAuth session: {e}")
+            return False
+
+    def delete_sessions_by_user_id(self, user_id: str) -> bool:
+        """Delete all OAuth sessions for a user"""
+        try:
+            with get_db() as db:
+                result = db.query(OAuthSession).filter_by(user_id=user_id).delete()
+                db.commit()
+                return True
+        except Exception as e:
+            log.error(f"Error deleting OAuth sessions by user ID: {e}")
+            return False
+
+
+OAuthSessions = OAuthSessionTable()

+ 17 - 15
backend/open_webui/routers/auths.py

@@ -19,6 +19,7 @@ from open_webui.models.auths import (
 )
 from open_webui.models.users import Users, UpdateProfileForm
 from open_webui.models.groups import Groups
+from open_webui.models.oauth_sessions import OAuthSessions
 
 from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 from open_webui.env import (
@@ -28,7 +29,6 @@ from open_webui.env import (
     WEBUI_AUTH_TRUSTED_GROUPS_HEADER,
     WEBUI_AUTH_COOKIE_SAME_SITE,
     WEBUI_AUTH_COOKIE_SECURE,
-    ENABLE_OAUTH_SESSION_TOKENS_COOKIES,
     WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
     ENABLE_INITIAL_ADMIN_SIGNUP,
     SRC_LOG_LEVELS,
@@ -678,24 +678,27 @@ async def signout(request: Request, response: Response):
     response.delete_cookie("token")
     response.delete_cookie("oui-session")
 
-    if ENABLE_OAUTH_SIGNUP.value:
-        # TODO: update this to use oauth_session_tokens in User Object
-        oauth_id_token = request.cookies.get("oauth_id_token")
+    oauth_session_id = request.cookies.get("oauth_session_id")
+    if oauth_session_id:
+        response.delete_cookie("oauth_session_id")
 
-        if oauth_id_token and OPENID_PROVIDER_URL.value:
+        session = OAuthSessions.get_session_by_id(oauth_session_id)
+        oauth_server_metadata_url = (
+            request.app.state.oauth_manager.get_server_metadata_url(session.provider)
+            if session
+            else None
+        ) or OPENID_PROVIDER_URL.value
+
+        if session and oauth_server_metadata_url:
+            oauth_id_token = session.token.get("id_token")
             try:
                 async with ClientSession(trust_env=True) as session:
-                    async with session.get(OPENID_PROVIDER_URL.value) as r:
+                    async with session.get(oauth_server_metadata_url) as r:
                         if r.status == 200:
                             openid_data = await r.json()
                             logout_url = openid_data.get("end_session_endpoint")
 
                             if logout_url:
-                                if ENABLE_OAUTH_SESSION_TOKENS_COOKIES:
-                                    response.delete_cookie("oauth_id_token")
-                                    response.delete_cookie("oauth_access_token")
-                                    response.delete_cookie("oauth_refresh_token")
-
                                 return JSONResponse(
                                     status_code=200,
                                     content={
@@ -710,15 +713,14 @@ async def signout(request: Request, response: Response):
                                     headers=response.headers,
                                 )
                         else:
-                            raise HTTPException(
-                                status_code=r.status,
-                                detail="Failed to fetch OpenID configuration",
-                            )
+                            raise Exception("Failed to fetch OpenID configuration")
+
             except Exception as e:
                 log.error(f"OpenID signout error: {str(e)}")
                 raise HTTPException(
                     status_code=500,
                     detail="Failed to sign out from the OpenID provider.",
+                    headers=response.headers,
                 )
 
     if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:

+ 52 - 50
backend/open_webui/utils/auth.py

@@ -261,61 +261,63 @@ def get_current_user(
         return user
 
     # auth by jwt token
-    try:
-        data = decode_token(token)
-    except Exception as e:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="Invalid token",
-        )
 
-    if data is not None and "id" in data:
-        user = Users.get_user_by_id(data["id"])
-        if user is None:
+    try:
+        try:
+            data = decode_token(token)
+        except Exception as e:
             raise HTTPException(
                 status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.INVALID_TOKEN,
+                detail="Invalid token",
             )
-        else:
-            if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
-                trusted_email = request.headers.get(
-                    WEBUI_AUTH_TRUSTED_EMAIL_HEADER, ""
-                ).lower()
-                if trusted_email and user.email != trusted_email:
-                    # Delete the token cookie
-                    response.delete_cookie("token")
-                    # Delete OAuth token if present
-
-                    if request.cookies.get("oauth_id_token"):
-                        response.delete_cookie("oauth_id_token")
-                    if request.cookies.get("oauth_access_token"):
-                        response.delete_cookie("oauth_access_token")
-                    if request.cookies.get("oauth_refresh_token"):
-                        response.delete_cookie("oauth_refresh_token")
-
-                    raise HTTPException(
-                        status_code=status.HTTP_401_UNAUTHORIZED,
-                        detail="User mismatch. Please sign in again.",
-                    )
 
-            # Add user info to current span
-            current_span = trace.get_current_span()
-            if current_span:
-                current_span.set_attribute("client.user.id", user.id)
-                current_span.set_attribute("client.user.email", user.email)
-                current_span.set_attribute("client.user.role", user.role)
-                current_span.set_attribute("client.auth.type", "jwt")
-
-            # Refresh the user's last active timestamp asynchronously
-            # to prevent blocking the request
-            if background_tasks:
-                background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
-        return user
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.UNAUTHORIZED,
-        )
+        if data is not None and "id" in data:
+            user = Users.get_user_by_id(data["id"])
+            if user is None:
+                raise HTTPException(
+                    status_code=status.HTTP_401_UNAUTHORIZED,
+                    detail=ERROR_MESSAGES.INVALID_TOKEN,
+                )
+            else:
+                if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
+                    trusted_email = request.headers.get(
+                        WEBUI_AUTH_TRUSTED_EMAIL_HEADER, ""
+                    ).lower()
+                    if trusted_email and user.email != trusted_email:
+                        raise HTTPException(
+                            status_code=status.HTTP_401_UNAUTHORIZED,
+                            detail="User mismatch. Please sign in again.",
+                        )
+
+                # Add user info to current span
+                current_span = trace.get_current_span()
+                if current_span:
+                    current_span.set_attribute("client.user.id", user.id)
+                    current_span.set_attribute("client.user.email", user.email)
+                    current_span.set_attribute("client.user.role", user.role)
+                    current_span.set_attribute("client.auth.type", "jwt")
+
+                # Refresh the user's last active timestamp asynchronously
+                # to prevent blocking the request
+                if background_tasks:
+                    background_tasks.add_task(
+                        Users.update_user_last_active_by_id, user.id
+                    )
+            return user
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.UNAUTHORIZED,
+            )
+    except Exception as e:
+        # Delete the token cookie
+        if request.cookies.get("token"):
+            response.delete_cookie("token")
+        # Delete OAuth session if present
+        if request.cookies.get("oauth_session_id"):
+            response.delete_cookie("oauth_session_id")
+
+        raise e
 
 
 def get_current_user_by_api_key(api_key: str):

+ 9 - 0
backend/open_webui/utils/middleware.py

@@ -815,6 +815,14 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     event_emitter = get_event_emitter(metadata)
     event_call = get_event_call(metadata)
 
+    oauth_token = None
+    try:
+        oauth_token = await request.app.state.oauth_manager.get_oauth_token(
+            user.id, request.cookies.get("oauth_session_id", None)
+        )
+    except Exception as e:
+        log.error(f"Error getting OAuth token: {e}")
+
     extra_params = {
         "__event_emitter__": event_emitter,
         "__event_call__": event_call,
@@ -822,6 +830,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
         "__metadata__": metadata,
         "__request__": request,
         "__model__": model,
+        "__oauth_token__": oauth_token,
     }
 
     # Initialize events to store additional event to be sent to the client

+ 219 - 28
backend/open_webui/utils/oauth.py

@@ -4,9 +4,11 @@ import mimetypes
 import sys
 import uuid
 import json
+from datetime import datetime, timedelta
 
 import re
 import fnmatch
+import time
 
 import aiohttp
 from authlib.integrations.starlette_client import OAuth
@@ -17,8 +19,12 @@ from fastapi import (
 )
 from starlette.responses import RedirectResponse
 
+
 from open_webui.models.auths import Auths
+from open_webui.models.oauth_sessions import OAuthSessions
 from open_webui.models.users import Users
+
+
 from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
 from open_webui.config import (
     DEFAULT_USER_ROLE,
@@ -49,7 +55,7 @@ from open_webui.env import (
     WEBUI_NAME,
     WEBUI_AUTH_COOKIE_SAME_SITE,
     WEBUI_AUTH_COOKIE_SECURE,
-    ENABLE_OAUTH_SESSION_TOKENS_COOKIES,
+    ENABLE_OAUTH_ID_TOKEN_COOKIE,
 )
 from open_webui.utils.misc import parse_duration
 from open_webui.utils.auth import get_password_hash, create_token
@@ -131,11 +137,187 @@ class OAuthManager:
     def __init__(self, app):
         self.oauth = OAuth()
         self.app = app
+
+        self._clients = {}
         for _, provider_config in OAUTH_PROVIDERS.items():
             provider_config["register"](self.oauth)
 
     def get_client(self, provider_name):
-        return self.oauth.create_client(provider_name)
+        if provider_name not in self._clients:
+            self._clients[provider_name] = self.oauth.create_client(provider_name)
+        return self._clients[provider_name]
+
+    def get_server_metadata_url(self, provider_name):
+        if provider_name in self._clients:
+            client = self._clients[provider_name]
+            return (
+                client.server_metadata_url
+                if hasattr(client, "server_metadata_url")
+                else None
+            )
+        return None
+
+    def get_oauth_token(
+        self, user_id: str, session_id: str, force_refresh: bool = False
+    ):
+        """
+        Get a valid OAuth token for the user, automatically refreshing if needed.
+
+        Args:
+            user_id: The user ID
+            provider: Optional provider name. If None, gets the most recent session.
+            force_refresh: Force token refresh even if current token appears valid
+
+        Returns:
+            dict: OAuth token data with access_token, or None if no valid token available
+        """
+        try:
+            # Get the OAuth session
+            session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
+            if not session:
+                log.warning(
+                    f"No OAuth session found for user {user_id}, session {session_id}"
+                )
+                return None
+
+            if force_refresh or datetime.now() + timedelta(
+                minutes=5
+            ) >= datetime.fromtimestamp(session.expires_at):
+                log.debug(
+                    f"Token refresh needed for user {user_id}, provider {session.provider}"
+                )
+                refreshed_token = self._refresh_token(session)
+                if refreshed_token:
+                    return refreshed_token
+                else:
+                    log.warning(
+                        f"Token refresh failed for user {user_id}, provider {session.provider}"
+                    )
+                    return None
+            return session.token
+
+        except Exception as e:
+            log.error(f"Error getting OAuth token for user {user_id}: {e}")
+            return None
+
+    async def _refresh_token(self, session) -> dict:
+        """
+        Refresh an OAuth token if needed, with concurrency protection.
+
+        Args:
+            session: The OAuth session object
+
+        Returns:
+            dict: Refreshed token data, or None if refresh failed
+        """
+        try:
+            # Perform the actual refresh
+            refreshed_token = await self._perform_token_refresh(session)
+
+            if refreshed_token:
+                # Update the session with new token data
+                session = OAuthSessions.update_session_by_id(
+                    session.id, refreshed_token
+                )
+                log.info(f"Successfully refreshed token for session {session.id}")
+                return session.token
+            else:
+                log.error(f"Failed to refresh token for session {session.id}")
+                return None
+
+        except Exception as e:
+            log.error(f"Error refreshing token for session {session.id}: {e}")
+            return None
+
+    async def _perform_token_refresh(self, session) -> dict:
+        """
+        Perform the actual OAuth token refresh.
+
+        Args:
+            session: The OAuth session object
+
+        Returns:
+            dict: New token data, or None if refresh failed
+        """
+        provider = session.provider
+        token_data = session.token
+
+        if not token_data.get("refresh_token"):
+            log.warning(f"No refresh token available for session {session.id}")
+            return None
+
+        try:
+            client = self.get_client(provider)
+            if not client:
+                log.error(f"No OAuth client found for provider {provider}")
+                return None
+
+            token_endpoint = None
+            async with aiohttp.ClientSession(trust_env=True) as session_http:
+                async with session_http.get(client.gserver_metadata_url) as r:
+                    if r.status == 200:
+                        openid_data = await r.json()
+                        token_endpoint = openid_data.get("token_endpoint")
+                    else:
+                        log.error(
+                            f"Failed to fetch OpenID configuration for provider {provider}"
+                        )
+            if not token_endpoint:
+                log.error(f"No token endpoint found for provider {provider}")
+                return None
+
+            # Prepare refresh request
+            refresh_data = {
+                "grant_type": "refresh_token",
+                "refresh_token": token_data["refresh_token"],
+                "client_id": client.client_id,
+            }
+            # Add client_secret if available (some providers require it)
+            if hasattr(client, "client_secret") and client.client_secret:
+                refresh_data["client_secret"] = client.client_secret
+
+            # Make refresh request
+            async with aiohttp.ClientSession(trust_env=True) as session_http:
+                async with session_http.post(
+                    token_endpoint,
+                    data=refresh_data,
+                    headers={"Content-Type": "application/x-www-form-urlencoded"},
+                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
+                ) as r:
+                    if r.status == 200:
+                        new_token_data = await r.json()
+
+                        # Merge with existing token data (preserve refresh_token if not provided)
+                        if "refresh_token" not in new_token_data:
+                            new_token_data["refresh_token"] = token_data[
+                                "refresh_token"
+                            ]
+
+                        # Add timestamp for tracking
+                        new_token_data["issued_at"] = datetime.now().timestamp()
+
+                        # Calculate expires_at if we have expires_in
+                        if (
+                            "expires_in" in new_token_data
+                            and "expires_at" not in new_token_data
+                        ):
+                            new_token_data["expires_at"] = (
+                                datetime.now().timestamp()
+                                + new_token_data["expires_in"]
+                            )
+
+                        log.debug(f"Token refresh successful for provider {provider}")
+                        return new_token_data
+                    else:
+                        error_text = await r.text()
+                        log.error(
+                            f"Token refresh failed for provider {provider}: {r.status} - {error_text}"
+                        )
+                        return None
+
+        except Exception as e:
+            log.error(f"Exception during token refresh for provider {provider}: {e}")
+            return None
 
     def get_user_role(self, user, user_data):
         user_count = Users.get_num_users()
@@ -624,33 +806,42 @@ class OAuthManager:
             secure=WEBUI_AUTH_COOKIE_SECURE,
         )
 
-        if ENABLE_OAUTH_SIGNUP.value:
-            if ENABLE_OAUTH_SESSION_TOKENS_COOKIES:
-                oauth_id_token = token.get("id_token")
-                response.set_cookie(
-                    key="oauth_id_token",
-                    value=oauth_id_token,
-                    httponly=True,
-                    samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
-                    secure=WEBUI_AUTH_COOKIE_SECURE,
-                )
+        # Legacy cookies for compatibility with older frontend versions
+        if ENABLE_OAUTH_ID_TOKEN_COOKIE:
+            response.set_cookie(
+                key="oauth_id_token",
+                value=token.get("id_token"),
+                httponly=True,
+                samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
+                secure=WEBUI_AUTH_COOKIE_SECURE,
+            )
 
-                oauth_access_token = token.get("access_token")
-                response.set_cookie(
-                    key="oauth_access_token",
-                    value=oauth_access_token,
-                    httponly=True,
-                    samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
-                    secure=WEBUI_AUTH_COOKIE_SECURE,
-                )
+        try:
+            # Add timestamp for tracking
+            token["issued_at"] = datetime.now().timestamp()
 
-                oauth_refresh_token = token.get("refresh_token")
-                response.set_cookie(
-                    key="oauth_refresh_token",
-                    value=oauth_refresh_token,
-                    httponly=True,
-                    samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
-                    secure=WEBUI_AUTH_COOKIE_SECURE,
-                )
+            # Calculate expires_at if we have expires_in
+            if "expires_in" in token and "expires_at" not in token:
+                token["expires_at"] = datetime.now().timestamp() + token["expires_in"]
+
+            session_id = await OAuthSessions.create_session(
+                user_id=user.id,
+                provider=provider,
+                token=token,
+            )
+
+            response.set_cookie(
+                key="oauth_session_id",
+                value=session_id,
+                httponly=True,
+                samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
+                secure=WEBUI_AUTH_COOKIE_SECURE,
+            )
+
+            log.info(
+                f"Stored OAuth session server-side for user {user.id}, provider {provider}"
+            )
+        except Exception as e:
+            log.error(f"Failed to store OAuth session server-side: {e}")
 
         return response

+ 15 - 0
backend/open_webui/utils/tools.py

@@ -129,6 +129,21 @@ async def get_tools(
                         headers["Authorization"] = (
                             f"Bearer {request.state.token.credentials}"
                         )
+                    elif auth_type == "oauth":
+                        oauth_token = None
+                        try:
+                            oauth_token = (
+                                await request.app.state.oauth_manager.get_oauth_token(
+                                    user.id,
+                                    request.cookies.get("oauth_session_id", None),
+                                )
+                            )
+                        except Exception as e:
+                            log.error(f"Error getting OAuth token: {e}")
+
+                        headers["Authorization"] = (
+                            f"Bearer {oauth_token.get('access_token', '')}"
+                        )
                     elif auth_type == "request_headers":
                         headers.update(dict(request.headers))
 

+ 7 - 0
src/lib/components/AddServerModal.svelte

@@ -287,6 +287,7 @@
 											<option value="session">{$i18n.t('Session')}</option>
 
 											{#if !direct}
+												<option value="oauth">{$i18n.t('OAuth')}</option>
 												<option value="request_headers">{$i18n.t('Request Headers')}</option>
 											{/if}
 										</select>
@@ -305,6 +306,12 @@
 											>
 												{$i18n.t('Forwards system user session credentials to authenticate')}
 											</div>
+										{:else if auth_type === 'oauth'}
+											<div
+												class={`text-xs self-center translate-y-[1px] ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
+											>
+												{$i18n.t('Forwards user OAuth access token to authenticate')}
+											</div>
 										{:else if auth_type === 'request_headers'}
 											<div
 												class={`text-xs self-center translate-y-[1px] ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}