|
@@ -1,4 +1,5 @@
|
|
|
import base64
|
|
import base64
|
|
|
|
|
+import copy
|
|
|
import hashlib
|
|
import hashlib
|
|
|
import logging
|
|
import logging
|
|
|
import mimetypes
|
|
import mimetypes
|
|
@@ -74,6 +75,8 @@ from mcp.shared.auth import (
|
|
|
OAuthMetadata,
|
|
OAuthMetadata,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+from authlib.oauth2.rfc6749.errors import OAuth2Error
|
|
|
|
|
+
|
|
|
|
|
|
|
|
class OAuthClientInformationFull(OAuthClientMetadata):
|
|
class OAuthClientInformationFull(OAuthClientMetadata):
|
|
|
issuer: Optional[str] = None # URL of the OAuth server that issued this client
|
|
issuer: Optional[str] = None # URL of the OAuth server that issued this client
|
|
@@ -150,6 +153,37 @@ def decrypt_data(data: str):
|
|
|
raise
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def _build_oauth_callback_error_message(exc: Exception) -> str:
|
|
|
|
|
+ """
|
|
|
|
|
+ Produce a user-facing callback error string with actionable context.
|
|
|
|
|
+ Keeps the message short and strips newlines for safe redirect usage.
|
|
|
|
|
+ """
|
|
|
|
|
+ if isinstance(exc, OAuth2Error):
|
|
|
|
|
+ parts = [p for p in [exc.error, exc.description] if p]
|
|
|
|
|
+ detail = " - ".join(parts)
|
|
|
|
|
+ elif isinstance(exc, HTTPException):
|
|
|
|
|
+ detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
|
|
|
|
|
+ elif isinstance(exc, aiohttp.ClientResponseError):
|
|
|
|
|
+ detail = f"Upstream provider returned {exc.status}: {exc.message}"
|
|
|
|
|
+ elif isinstance(exc, aiohttp.ClientError):
|
|
|
|
|
+ detail = str(exc)
|
|
|
|
|
+ elif isinstance(exc, KeyError):
|
|
|
|
|
+ missing = str(exc).strip("'")
|
|
|
|
|
+ if missing.lower() == "state":
|
|
|
|
|
+ detail = "Missing state parameter in callback (session may have expired)"
|
|
|
|
|
+ else:
|
|
|
|
|
+ detail = f"Missing expected key '{missing}' in OAuth response"
|
|
|
|
|
+ else:
|
|
|
|
|
+ detail = str(exc)
|
|
|
|
|
+
|
|
|
|
|
+ detail = detail.replace("\n", " ").strip()
|
|
|
|
|
+ if not detail:
|
|
|
|
|
+ detail = exc.__class__.__name__
|
|
|
|
|
+
|
|
|
|
|
+ message = f"OAuth callback failed: {detail}"
|
|
|
|
|
+ return message[:197] + "..." if len(message) > 200 else message
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def is_in_blocked_groups(group_name: str, groups: list) -> bool:
|
|
def is_in_blocked_groups(group_name: str, groups: list) -> bool:
|
|
|
"""
|
|
"""
|
|
|
Check if a group name matches any blocked pattern.
|
|
Check if a group name matches any blocked pattern.
|
|
@@ -368,11 +402,221 @@ class OAuthClientManager:
|
|
|
return self.clients[client_id]
|
|
return self.clients[client_id]
|
|
|
|
|
|
|
|
def remove_client(self, client_id):
|
|
def remove_client(self, client_id):
|
|
|
|
|
+ removed = False
|
|
|
if client_id in self.clients:
|
|
if client_id in self.clients:
|
|
|
del self.clients[client_id]
|
|
del self.clients[client_id]
|
|
|
|
|
+ removed = True
|
|
|
|
|
+ if hasattr(self.oauth, "_clients"):
|
|
|
|
|
+ if client_id in self.oauth._clients:
|
|
|
|
|
+ self.oauth._clients.pop(client_id, None)
|
|
|
|
|
+ removed = True
|
|
|
|
|
+ if hasattr(self.oauth, "_registry"):
|
|
|
|
|
+ if client_id in self.oauth._registry:
|
|
|
|
|
+ self.oauth._registry.pop(client_id, None)
|
|
|
|
|
+ removed = True
|
|
|
|
|
+ if removed:
|
|
|
log.info(f"Removed OAuth client {client_id}")
|
|
log.info(f"Removed OAuth client {client_id}")
|
|
|
return True
|
|
return True
|
|
|
|
|
|
|
|
|
|
+ def _find_mcp_connection(self, request, client_id: str):
|
|
|
|
|
+ try:
|
|
|
|
|
+ connections = request.app.state.config.TOOL_SERVER_CONNECTIONS or []
|
|
|
|
|
+ except Exception:
|
|
|
|
|
+ connections = []
|
|
|
|
|
+
|
|
|
|
|
+ normalized_client_id = client_id.split(":")[-1]
|
|
|
|
|
+
|
|
|
|
|
+ for idx, connection in enumerate(connections):
|
|
|
|
|
+ if not isinstance(connection, dict):
|
|
|
|
|
+ continue
|
|
|
|
|
+ if connection.get("type") != "mcp":
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ info = connection.get("info") or {}
|
|
|
|
|
+ server_id = info.get("id")
|
|
|
|
|
+ if not server_id:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ normalized_server_id = server_id.split(":")[-1]
|
|
|
|
|
+ if normalized_server_id == normalized_client_id:
|
|
|
|
|
+ return idx, connection
|
|
|
|
|
+
|
|
|
|
|
+ return None, None
|
|
|
|
|
+
|
|
|
|
|
+ async def _preflight_authorization_url(
|
|
|
|
|
+ self, client, client_info: OAuthClientInformationFull
|
|
|
|
|
+ ) -> bool:
|
|
|
|
|
+ # Only perform preflight checks for Starlette OAuth clients
|
|
|
|
|
+ if not hasattr(client, "create_authorization_url"):
|
|
|
|
|
+ return True
|
|
|
|
|
+
|
|
|
|
|
+ redirect_uri = None
|
|
|
|
|
+ if client_info.redirect_uris:
|
|
|
|
|
+ redirect_uri = str(client_info.redirect_uris[0])
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ auth_data = await client.create_authorization_url(redirect_uri=redirect_uri)
|
|
|
|
|
+ authorize_url = auth_data.get("url")
|
|
|
|
|
+ if not authorize_url:
|
|
|
|
|
+ return True
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ log.debug(
|
|
|
|
|
+ "Skipping OAuth preflight for client %s: %s",
|
|
|
|
|
+ client_info.client_id,
|
|
|
|
|
+ e,
|
|
|
|
|
+ )
|
|
|
|
|
+ return True
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ async with aiohttp.ClientSession(trust_env=True) as session:
|
|
|
|
|
+ async with session.get(
|
|
|
|
|
+ authorize_url,
|
|
|
|
|
+ allow_redirects=False,
|
|
|
|
|
+ ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
|
|
|
|
+ ) as resp:
|
|
|
|
|
+ if resp.status < 400:
|
|
|
|
|
+ return True
|
|
|
|
|
+
|
|
|
|
|
+ body_text = await resp.text()
|
|
|
|
|
+ error = None
|
|
|
|
|
+ error_description = ""
|
|
|
|
|
+ content_type = resp.headers.get("content-type", "")
|
|
|
|
|
+
|
|
|
|
|
+ if "application/json" in content_type:
|
|
|
|
|
+ try:
|
|
|
|
|
+ payload = json.loads(body_text)
|
|
|
|
|
+ error = payload.get("error")
|
|
|
|
|
+ error_description = payload.get("error_description", "")
|
|
|
|
|
+ except json.JSONDecodeError:
|
|
|
|
|
+ error = None
|
|
|
|
|
+ error_description = ""
|
|
|
|
|
+ else:
|
|
|
|
|
+ error_description = body_text
|
|
|
|
|
+
|
|
|
|
|
+ combined = f"{error or ''} {error_description}".lower()
|
|
|
|
|
+ if (
|
|
|
|
|
+ "invalid_client" in combined
|
|
|
|
|
+ or "invalid client" in combined
|
|
|
|
|
+ or "client id" in combined
|
|
|
|
|
+ ):
|
|
|
|
|
+ log.warning(
|
|
|
|
|
+ "OAuth client preflight detected invalid registration for %s: %s %s",
|
|
|
|
|
+ client_info.client_id,
|
|
|
|
|
+ error,
|
|
|
|
|
+ error_description,
|
|
|
|
|
+ )
|
|
|
|
|
+ return False
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ log.debug(
|
|
|
|
|
+ "Skipping OAuth preflight network check for client %s: %s",
|
|
|
|
|
+ client_info.client_id,
|
|
|
|
|
+ e,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ return True
|
|
|
|
|
+
|
|
|
|
|
+ async def _re_register_client(self, request, client_id: str) -> bool:
|
|
|
|
|
+ idx, connection = self._find_mcp_connection(request, client_id)
|
|
|
|
|
+ if idx is None or connection is None:
|
|
|
|
|
+ log.warning(
|
|
|
|
|
+ "Unable to locate MCP tool server configuration for client %s during re-registration",
|
|
|
|
|
+ client_id,
|
|
|
|
|
+ )
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+ server_url = connection.get("url")
|
|
|
|
|
+ oauth_server_key = (connection.get("config") or {}).get("oauth_server_key")
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ oauth_client_info = (
|
|
|
|
|
+ await get_oauth_client_info_with_dynamic_client_registration(
|
|
|
|
|
+ request,
|
|
|
|
|
+ client_id,
|
|
|
|
|
+ server_url,
|
|
|
|
|
+ oauth_server_key,
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ log.error(
|
|
|
|
|
+ "Dynamic client re-registration failed for %s: %s",
|
|
|
|
|
+ client_id,
|
|
|
|
|
+ e,
|
|
|
|
|
+ )
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+ encrypted_info = encrypt_data(oauth_client_info.model_dump(mode="json"))
|
|
|
|
|
+
|
|
|
|
|
+ updated_connections = copy.deepcopy(
|
|
|
|
|
+ request.app.state.config.TOOL_SERVER_CONNECTIONS or []
|
|
|
|
|
+ )
|
|
|
|
|
+ if idx >= len(updated_connections):
|
|
|
|
|
+ log.error(
|
|
|
|
|
+ "MCP tool server index %s out of range during OAuth client re-registration for %s",
|
|
|
|
|
+ idx,
|
|
|
|
|
+ client_id,
|
|
|
|
|
+ )
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+ updated_connection = copy.deepcopy(connection)
|
|
|
|
|
+ updated_connection.setdefault("info", {})
|
|
|
|
|
+ updated_connection["info"]["oauth_client_info"] = encrypted_info
|
|
|
|
|
+ updated_connections[idx] = updated_connection
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ request.app.state.config.TOOL_SERVER_CONNECTIONS = updated_connections
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ log.error(
|
|
|
|
|
+ "Failed to persist updated OAuth client info for %s: %s",
|
|
|
|
|
+ client_id,
|
|
|
|
|
+ e,
|
|
|
|
|
+ )
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+ self.remove_client(client_id)
|
|
|
|
|
+ self.add_client(client_id, oauth_client_info)
|
|
|
|
|
+ OAuthSessions.delete_sessions_by_provider(client_id)
|
|
|
|
|
+
|
|
|
|
|
+ log.info("Re-registered OAuth client %s for MCP tool server", client_id)
|
|
|
|
|
+ return True
|
|
|
|
|
+
|
|
|
|
|
+ async def _ensure_valid_client_registration(self, request, client_id: str) -> None:
|
|
|
|
|
+ if not client_id.startswith("mcp:"):
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ client = self.get_client(client_id)
|
|
|
|
|
+ client_info = self.get_client_info(client_id)
|
|
|
|
|
+ if client is None or client_info is None:
|
|
|
|
|
+ raise HTTPException(status.HTTP_404_NOT_FOUND)
|
|
|
|
|
+
|
|
|
|
|
+ is_valid = await self._preflight_authorization_url(client, client_info)
|
|
|
|
|
+ if is_valid:
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ log.info(
|
|
|
|
|
+ "Detected invalid OAuth client %s; attempting re-registration",
|
|
|
|
|
+ client_id,
|
|
|
|
|
+ )
|
|
|
|
|
+ re_registered = await self._re_register_client(request, client_id)
|
|
|
|
|
+ if not re_registered:
|
|
|
|
|
+ raise HTTPException(
|
|
|
|
|
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
|
|
+ detail="Failed to re-register OAuth client",
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ client = self.get_client(client_id)
|
|
|
|
|
+ client_info = self.get_client_info(client_id)
|
|
|
|
|
+ if client is None or client_info is None:
|
|
|
|
|
+ raise HTTPException(
|
|
|
|
|
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
|
|
+ detail="OAuth client unavailable after re-registration",
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if not await self._preflight_authorization_url(client, client_info):
|
|
|
|
|
+ raise HTTPException(
|
|
|
|
|
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
|
|
+ detail="OAuth client registration is still invalid after re-registration",
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
def get_client(self, client_id):
|
|
def get_client(self, client_id):
|
|
|
client = self.clients.get(client_id)
|
|
client = self.clients.get(client_id)
|
|
|
return client["client"] if client else None
|
|
return client["client"] if client else None
|
|
@@ -558,10 +802,11 @@ class OAuthClientManager:
|
|
|
return None
|
|
return None
|
|
|
|
|
|
|
|
async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
|
|
async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
|
|
|
|
|
+ await self._ensure_valid_client_registration(request, client_id)
|
|
|
|
|
+
|
|
|
client = self.get_client(client_id)
|
|
client = self.get_client(client_id)
|
|
|
if client is None:
|
|
if client is None:
|
|
|
raise HTTPException(404)
|
|
raise HTTPException(404)
|
|
|
-
|
|
|
|
|
client_info = self.get_client_info(client_id)
|
|
client_info = self.get_client_info(client_id)
|
|
|
if client_info is None:
|
|
if client_info is None:
|
|
|
raise HTTPException(404)
|
|
raise HTTPException(404)
|
|
@@ -569,7 +814,8 @@ class OAuthClientManager:
|
|
|
redirect_uri = (
|
|
redirect_uri = (
|
|
|
client_info.redirect_uris[0] if client_info.redirect_uris else None
|
|
client_info.redirect_uris[0] if client_info.redirect_uris else None
|
|
|
)
|
|
)
|
|
|
- return await client.authorize_redirect(request, str(redirect_uri))
|
|
|
|
|
|
|
+ redirect_uri_str = str(redirect_uri) if redirect_uri else None
|
|
|
|
|
+ return await client.authorize_redirect(request, redirect_uri_str)
|
|
|
|
|
|
|
|
async def handle_callback(self, request, client_id: str, user_id: str, response):
|
|
async def handle_callback(self, request, client_id: str, user_id: str, response):
|
|
|
client = self.get_client(client_id)
|
|
client = self.get_client(client_id)
|
|
@@ -621,8 +867,14 @@ class OAuthClientManager:
|
|
|
error_message = "Failed to obtain OAuth token"
|
|
error_message = "Failed to obtain OAuth token"
|
|
|
log.warning(error_message)
|
|
log.warning(error_message)
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
- error_message = "OAuth callback error"
|
|
|
|
|
- log.warning(f"OAuth callback error: {e}")
|
|
|
|
|
|
|
+ error_message = _build_oauth_callback_error_message(e)
|
|
|
|
|
+ log.warning(
|
|
|
|
|
+ "OAuth callback error for user_id=%s client_id=%s: %s",
|
|
|
|
|
+ user_id,
|
|
|
|
|
+ client_id,
|
|
|
|
|
+ error_message,
|
|
|
|
|
+ exc_info=True,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
redirect_url = (
|
|
redirect_url = (
|
|
|
str(request.app.state.config.WEBUI_URL or request.base_url)
|
|
str(request.app.state.config.WEBUI_URL or request.base_url)
|
|
@@ -630,7 +882,9 @@ class OAuthClientManager:
|
|
|
|
|
|
|
|
if error_message:
|
|
if error_message:
|
|
|
log.debug(error_message)
|
|
log.debug(error_message)
|
|
|
- redirect_url = f"{redirect_url}/?error={error_message}"
|
|
|
|
|
|
|
+ redirect_url = (
|
|
|
|
|
+ f"{redirect_url}/?error={urllib.parse.quote_plus(error_message)}"
|
|
|
|
|
+ )
|
|
|
return RedirectResponse(url=redirect_url, headers=response.headers)
|
|
return RedirectResponse(url=redirect_url, headers=response.headers)
|
|
|
|
|
|
|
|
response = RedirectResponse(url=redirect_url, headers=response.headers)
|
|
response = RedirectResponse(url=redirect_url, headers=response.headers)
|
|
@@ -1104,7 +1358,13 @@ class OAuthManager:
|
|
|
try:
|
|
try:
|
|
|
token = await client.authorize_access_token(request)
|
|
token = await client.authorize_access_token(request)
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
- log.warning(f"OAuth callback error: {e}")
|
|
|
|
|
|
|
+ detailed_error = _build_oauth_callback_error_message(e)
|
|
|
|
|
+ log.warning(
|
|
|
|
|
+ "OAuth callback error during authorize_access_token for provider %s: %s",
|
|
|
|
|
+ provider,
|
|
|
|
|
+ detailed_error,
|
|
|
|
|
+ exc_info=True,
|
|
|
|
|
+ )
|
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
|
|
|
|
|
# Try to get userinfo from the token first, some providers include it there
|
|
# Try to get userinfo from the token first, some providers include it there
|