Просмотр исходного кода

Merge pull request #18415 from taylorwilsdon/oauth_error_handling_enh

enh: More detailed OAuth2.1 tool callback error handling + fix for editing existing tools
Tim Baek 3 месяцев назад
Родитель
Сommit
bfadbc9934

+ 11 - 0
backend/open_webui/models/oauth_sessions.py

@@ -262,5 +262,16 @@ class OAuthSessionTable:
             log.error(f"Error deleting OAuth sessions by user ID: {e}")
             log.error(f"Error deleting OAuth sessions by user ID: {e}")
             return False
             return False
 
 
+    def delete_sessions_by_provider(self, provider: str) -> bool:
+        """Delete all OAuth sessions for a provider"""
+        try:
+            with get_db() as db:
+                db.query(OAuthSession).filter_by(provider=provider).delete()
+                db.commit()
+                return True
+        except Exception as e:
+            log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
+            return False
+
 
 
 OAuthSessions = OAuthSessionTable()
 OAuthSessions = OAuthSessionTable()

+ 50 - 1
backend/open_webui/routers/configs.py

@@ -1,4 +1,5 @@
 import logging
 import logging
+import copy
 from fastapi import APIRouter, Depends, Request, HTTPException
 from fastapi import APIRouter, Depends, Request, HTTPException
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 import aiohttp
 import aiohttp
@@ -15,6 +16,7 @@ from open_webui.utils.tools import (
     set_tool_servers,
     set_tool_servers,
 )
 )
 from open_webui.utils.mcp.client import MCPClient
 from open_webui.utils.mcp.client import MCPClient
+from open_webui.models.oauth_sessions import OAuthSessions
 
 
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
@@ -165,12 +167,59 @@ async def set_tool_servers_config(
     form_data: ToolServersConfigForm,
     form_data: ToolServersConfigForm,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
 ):
 ):
-    request.app.state.config.TOOL_SERVER_CONNECTIONS = [
+    old_connections = copy.deepcopy(
+        request.app.state.config.TOOL_SERVER_CONNECTIONS or []
+    )
+
+    new_connections = [
         connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
         connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
     ]
     ]
 
 
+    old_mcp_connections = {
+        conn.get("info", {}).get("id"): conn
+        for conn in old_connections
+        if conn.get("type") == "mcp"
+    }
+    new_mcp_connections = {
+        conn.get("info", {}).get("id"): conn
+        for conn in new_connections
+        if conn.get("type") == "mcp"
+    }
+
+    purge_oauth_clients = set()
+
+    for server_id, old_conn in old_mcp_connections.items():
+        if not server_id:
+            continue
+
+        old_auth_type = old_conn.get("auth_type", "none")
+        new_conn = new_mcp_connections.get(server_id)
+
+        if new_conn is None:
+            if old_auth_type == "oauth_2.1":
+                purge_oauth_clients.add(server_id)
+            continue
+
+        new_auth_type = new_conn.get("auth_type", "none")
+
+        if old_auth_type == "oauth_2.1":
+            if (
+                new_auth_type != "oauth_2.1"
+                or old_conn.get("url") != new_conn.get("url")
+                or old_conn.get("info", {}).get("oauth_client_info")
+                != new_conn.get("info", {}).get("oauth_client_info")
+            ):
+                purge_oauth_clients.add(server_id)
+
+    request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections
+
     await set_tool_servers(request)
     await set_tool_servers(request)
 
 
+    for server_id in purge_oauth_clients:
+        client_key = f"mcp:{server_id}"
+        request.app.state.oauth_client_manager.remove_client(client_key)
+        OAuthSessions.delete_sessions_by_provider(client_key)
+
     for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
     for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
         server_type = connection.get("type", "openapi")
         server_type = connection.get("type", "openapi")
         if server_type == "mcp":
         if server_type == "mcp":

+ 266 - 6
backend/open_webui/utils/oauth.py

@@ -1,4 +1,5 @@
 import base64
 import base64
+import copy
 import hashlib
 import hashlib
 import logging
 import logging
 import mimetypes
 import mimetypes
@@ -74,6 +75,8 @@ from mcp.shared.auth import (
     OAuthMetadata,
     OAuthMetadata,
 )
 )
 
 
+from authlib.oauth2.rfc6749.errors import OAuth2Error
+
 
 
 class OAuthClientInformationFull(OAuthClientMetadata):
 class OAuthClientInformationFull(OAuthClientMetadata):
     issuer: Optional[str] = None  # URL of the OAuth server that issued this client
     issuer: Optional[str] = None  # URL of the OAuth server that issued this client
@@ -150,6 +153,37 @@ def decrypt_data(data: str):
         raise
         raise
 
 
 
 
+def _build_oauth_callback_error_message(exc: Exception) -> str:
+    """
+    Produce a user-facing callback error string with actionable context.
+    Keeps the message short and strips newlines for safe redirect usage.
+    """
+    if isinstance(exc, OAuth2Error):
+        parts = [p for p in [exc.error, exc.description] if p]
+        detail = " - ".join(parts)
+    elif isinstance(exc, HTTPException):
+        detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
+    elif isinstance(exc, aiohttp.ClientResponseError):
+        detail = f"Upstream provider returned {exc.status}: {exc.message}"
+    elif isinstance(exc, aiohttp.ClientError):
+        detail = str(exc)
+    elif isinstance(exc, KeyError):
+        missing = str(exc).strip("'")
+        if missing.lower() == "state":
+            detail = "Missing state parameter in callback (session may have expired)"
+        else:
+            detail = f"Missing expected key '{missing}' in OAuth response"
+    else:
+        detail = str(exc)
+
+    detail = detail.replace("\n", " ").strip()
+    if not detail:
+        detail = exc.__class__.__name__
+
+    message = f"OAuth callback failed: {detail}"
+    return message[:197] + "..." if len(message) > 200 else message
+
+
 def is_in_blocked_groups(group_name: str, groups: list) -> bool:
 def is_in_blocked_groups(group_name: str, groups: list) -> bool:
     """
     """
     Check if a group name matches any blocked pattern.
     Check if a group name matches any blocked pattern.
@@ -368,11 +402,221 @@ class OAuthClientManager:
         return self.clients[client_id]
         return self.clients[client_id]
 
 
     def remove_client(self, client_id):
     def remove_client(self, client_id):
+        removed = False
         if client_id in self.clients:
         if client_id in self.clients:
             del self.clients[client_id]
             del self.clients[client_id]
+            removed = True
+        if hasattr(self.oauth, "_clients"):
+            if client_id in self.oauth._clients:
+                self.oauth._clients.pop(client_id, None)
+                removed = True
+        if hasattr(self.oauth, "_registry"):
+            if client_id in self.oauth._registry:
+                self.oauth._registry.pop(client_id, None)
+                removed = True
+        if removed:
             log.info(f"Removed OAuth client {client_id}")
             log.info(f"Removed OAuth client {client_id}")
         return True
         return True
 
 
+    def _find_mcp_connection(self, request, client_id: str):
+        try:
+            connections = request.app.state.config.TOOL_SERVER_CONNECTIONS or []
+        except Exception:
+            connections = []
+
+        normalized_client_id = client_id.split(":")[-1]
+
+        for idx, connection in enumerate(connections):
+            if not isinstance(connection, dict):
+                continue
+            if connection.get("type") != "mcp":
+                continue
+
+            info = connection.get("info") or {}
+            server_id = info.get("id")
+            if not server_id:
+                continue
+
+            normalized_server_id = server_id.split(":")[-1]
+            if normalized_server_id == normalized_client_id:
+                return idx, connection
+
+        return None, None
+
+    async def _preflight_authorization_url(
+        self, client, client_info: OAuthClientInformationFull
+    ) -> bool:
+        # Only perform preflight checks for Starlette OAuth clients
+        if not hasattr(client, "create_authorization_url"):
+            return True
+
+        redirect_uri = None
+        if client_info.redirect_uris:
+            redirect_uri = str(client_info.redirect_uris[0])
+
+        try:
+            auth_data = await client.create_authorization_url(redirect_uri=redirect_uri)
+            authorize_url = auth_data.get("url")
+            if not authorize_url:
+                return True
+        except Exception as e:
+            log.debug(
+                "Skipping OAuth preflight for client %s: %s",
+                client_info.client_id,
+                e,
+            )
+            return True
+
+        try:
+            async with aiohttp.ClientSession(trust_env=True) as session:
+                async with session.get(
+                    authorize_url,
+                    allow_redirects=False,
+                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
+                ) as resp:
+                    if resp.status < 400:
+                        return True
+
+                    body_text = await resp.text()
+                    error = None
+                    error_description = ""
+                    content_type = resp.headers.get("content-type", "")
+
+                    if "application/json" in content_type:
+                        try:
+                            payload = json.loads(body_text)
+                            error = payload.get("error")
+                            error_description = payload.get("error_description", "")
+                        except json.JSONDecodeError:
+                            error = None
+                            error_description = ""
+                    else:
+                        error_description = body_text
+
+                    combined = f"{error or ''} {error_description}".lower()
+                    if (
+                        "invalid_client" in combined
+                        or "invalid client" in combined
+                        or "client id" in combined
+                    ):
+                        log.warning(
+                            "OAuth client preflight detected invalid registration for %s: %s %s",
+                            client_info.client_id,
+                            error,
+                            error_description,
+                        )
+                        return False
+        except Exception as e:
+            log.debug(
+                "Skipping OAuth preflight network check for client %s: %s",
+                client_info.client_id,
+                e,
+            )
+
+        return True
+
+    async def _re_register_client(self, request, client_id: str) -> bool:
+        idx, connection = self._find_mcp_connection(request, client_id)
+        if idx is None or connection is None:
+            log.warning(
+                "Unable to locate MCP tool server configuration for client %s during re-registration",
+                client_id,
+            )
+            return False
+
+        server_url = connection.get("url")
+        oauth_server_key = (connection.get("config") or {}).get("oauth_server_key")
+
+        try:
+            oauth_client_info = (
+                await get_oauth_client_info_with_dynamic_client_registration(
+                    request,
+                    client_id,
+                    server_url,
+                    oauth_server_key,
+                )
+            )
+        except Exception as e:
+            log.error(
+                "Dynamic client re-registration failed for %s: %s",
+                client_id,
+                e,
+            )
+            return False
+
+        encrypted_info = encrypt_data(oauth_client_info.model_dump(mode="json"))
+
+        updated_connections = copy.deepcopy(
+            request.app.state.config.TOOL_SERVER_CONNECTIONS or []
+        )
+        if idx >= len(updated_connections):
+            log.error(
+                "MCP tool server index %s out of range during OAuth client re-registration for %s",
+                idx,
+                client_id,
+            )
+            return False
+
+        updated_connection = copy.deepcopy(connection)
+        updated_connection.setdefault("info", {})
+        updated_connection["info"]["oauth_client_info"] = encrypted_info
+        updated_connections[idx] = updated_connection
+
+        try:
+            request.app.state.config.TOOL_SERVER_CONNECTIONS = updated_connections
+        except Exception as e:
+            log.error(
+                "Failed to persist updated OAuth client info for %s: %s",
+                client_id,
+                e,
+            )
+            return False
+
+        self.remove_client(client_id)
+        self.add_client(client_id, oauth_client_info)
+        OAuthSessions.delete_sessions_by_provider(client_id)
+
+        log.info("Re-registered OAuth client %s for MCP tool server", client_id)
+        return True
+
+    async def _ensure_valid_client_registration(self, request, client_id: str) -> None:
+        if not client_id.startswith("mcp:"):
+            return
+
+        client = self.get_client(client_id)
+        client_info = self.get_client_info(client_id)
+        if client is None or client_info is None:
+            raise HTTPException(status.HTTP_404_NOT_FOUND)
+
+        is_valid = await self._preflight_authorization_url(client, client_info)
+        if is_valid:
+            return
+
+        log.info(
+            "Detected invalid OAuth client %s; attempting re-registration",
+            client_id,
+        )
+        re_registered = await self._re_register_client(request, client_id)
+        if not re_registered:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail="Failed to re-register OAuth client",
+            )
+
+        client = self.get_client(client_id)
+        client_info = self.get_client_info(client_id)
+        if client is None or client_info is None:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail="OAuth client unavailable after re-registration",
+            )
+
+        if not await self._preflight_authorization_url(client, client_info):
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail="OAuth client registration is still invalid after re-registration",
+            )
+
     def get_client(self, client_id):
     def get_client(self, client_id):
         client = self.clients.get(client_id)
         client = self.clients.get(client_id)
         return client["client"] if client else None
         return client["client"] if client else None
@@ -558,10 +802,11 @@ class OAuthClientManager:
             return None
             return None
 
 
     async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
     async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
+        await self._ensure_valid_client_registration(request, client_id)
+
         client = self.get_client(client_id)
         client = self.get_client(client_id)
         if client is None:
         if client is None:
             raise HTTPException(404)
             raise HTTPException(404)
-
         client_info = self.get_client_info(client_id)
         client_info = self.get_client_info(client_id)
         if client_info is None:
         if client_info is None:
             raise HTTPException(404)
             raise HTTPException(404)
@@ -569,7 +814,8 @@ class OAuthClientManager:
         redirect_uri = (
         redirect_uri = (
             client_info.redirect_uris[0] if client_info.redirect_uris else None
             client_info.redirect_uris[0] if client_info.redirect_uris else None
         )
         )
-        return await client.authorize_redirect(request, str(redirect_uri))
+        redirect_uri_str = str(redirect_uri) if redirect_uri else None
+        return await client.authorize_redirect(request, redirect_uri_str)
 
 
     async def handle_callback(self, request, client_id: str, user_id: str, response):
     async def handle_callback(self, request, client_id: str, user_id: str, response):
         client = self.get_client(client_id)
         client = self.get_client(client_id)
@@ -621,8 +867,14 @@ class OAuthClientManager:
                 error_message = "Failed to obtain OAuth token"
                 error_message = "Failed to obtain OAuth token"
                 log.warning(error_message)
                 log.warning(error_message)
         except Exception as e:
         except Exception as e:
-            error_message = "OAuth callback error"
-            log.warning(f"OAuth callback error: {e}")
+            error_message = _build_oauth_callback_error_message(e)
+            log.warning(
+                "OAuth callback error for user_id=%s client_id=%s: %s",
+                user_id,
+                client_id,
+                error_message,
+                exc_info=True,
+            )
 
 
         redirect_url = (
         redirect_url = (
             str(request.app.state.config.WEBUI_URL or request.base_url)
             str(request.app.state.config.WEBUI_URL or request.base_url)
@@ -630,7 +882,9 @@ class OAuthClientManager:
 
 
         if error_message:
         if error_message:
             log.debug(error_message)
             log.debug(error_message)
-            redirect_url = f"{redirect_url}/?error={error_message}"
+            redirect_url = (
+                f"{redirect_url}/?error={urllib.parse.quote_plus(error_message)}"
+            )
             return RedirectResponse(url=redirect_url, headers=response.headers)
             return RedirectResponse(url=redirect_url, headers=response.headers)
 
 
         response = RedirectResponse(url=redirect_url, headers=response.headers)
         response = RedirectResponse(url=redirect_url, headers=response.headers)
@@ -1104,7 +1358,13 @@ class OAuthManager:
             try:
             try:
                 token = await client.authorize_access_token(request)
                 token = await client.authorize_access_token(request)
             except Exception as e:
             except Exception as e:
-                log.warning(f"OAuth callback error: {e}")
+                detailed_error = _build_oauth_callback_error_message(e)
+                log.warning(
+                    "OAuth callback error during authorize_access_token for provider %s: %s",
+                    provider,
+                    detailed_error,
+                    exc_info=True,
+                )
                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 
 
             # Try to get userinfo from the token first, some providers include it there
             # Try to get userinfo from the token first, some providers include it there