Browse Source

feat: oauth2.1 mcp integration

Timothy Jaeryang Baek 2 weeks ago
parent
commit
77e971dd9f

+ 57 - 3
backend/open_webui/main.py

@@ -473,7 +473,12 @@ from open_webui.utils.auth import (
     get_verified_user,
 )
 from open_webui.utils.plugin import install_tool_and_function_dependencies
-from open_webui.utils.oauth import OAuthManager
+from open_webui.utils.oauth import (
+    OAuthManager,
+    OAuthClientManager,
+    decrypt_data,
+    OAuthClientInformationFull,
+)
 from open_webui.utils.security_headers import SecurityHeadersMiddleware
 from open_webui.utils.redis import get_redis_connection
 
@@ -603,9 +608,14 @@ app = FastAPI(
     lifespan=lifespan,
 )
 
+# For Open WebUI OIDC/OAuth2
 oauth_manager = OAuthManager(app)
 app.state.oauth_manager = oauth_manager
 
+# For Integrations
+oauth_client_manager = OAuthClientManager(app)
+app.state.oauth_client_manager = oauth_client_manager
+
 app.state.instance_id = None
 app.state.config = AppConfig(
     redis_url=REDIS_URL,
@@ -1881,6 +1891,24 @@ async def get_current_usage(user=Depends(get_verified_user)):
 # OAuth Login & Callback
 ############################
 
+
+# Initialize OAuth client manager with any MCP tool servers using OAuth 2.1
+if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0:
+    for tool_server_connection in app.state.config.TOOL_SERVER_CONNECTIONS:
+        if tool_server_connection.get("type", "openapi") == "mcp":
+            server_id = tool_server_connection.get("info", {}).get("id")
+            auth_type = tool_server_connection.get("auth_type", "none")
+            if server_id and auth_type == "oauth_2.1":
+                oauth_client_info = tool_server_connection.get("info", {}).get(
+                    "oauth_client_info"
+                )
+
+                oauth_client_info = decrypt_data(oauth_client_info)
+                app.state.oauth_client_manager.add_client(
+                    f"mcp:{server_id}", OAuthClientInformationFull(**oauth_client_info)
+                )
+
+
 # SessionMiddleware is used by authlib for oauth
 if len(OAUTH_PROVIDERS) > 0:
     try:
@@ -1913,6 +1941,31 @@ if len(OAUTH_PROVIDERS) > 0:
         )
 
 
+@app.get("/oauth/clients/{client_id}/authorize")
+async def oauth_client_authorize(
+    client_id: str,
+    request: Request,
+    response: Response,
+    user=Depends(get_verified_user),
+):
+    return await oauth_client_manager.handle_authorize(request, client_id=client_id)
+
+
+@app.get("/oauth/clients/{client_id}/callback")
+async def oauth_client_callback(
+    client_id: str,
+    request: Request,
+    response: Response,
+    user=Depends(get_verified_user),
+):
+    return await oauth_client_manager.handle_callback(
+        request,
+        client_id=client_id,
+        user_id=user.id if user else None,
+        response=response,
+    )
+
+
 @app.get("/oauth/{provider}/login")
 async def oauth_login(provider: str, request: Request):
     return await oauth_manager.handle_login(request, provider)
@@ -1924,8 +1977,9 @@ async def oauth_login(provider: str, request: Request):
 #    - This is considered insecure in general, as OAuth providers do not always verify email addresses
 # 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
 #    - Email addresses are considered unique, so we fail registration if the email address is already taken
-@app.get("/oauth/{provider}/callback")
-async def oauth_callback(provider: str, request: Request, response: Response):
+@app.get("/oauth/{provider}/callback")  # Legacy endpoint
+@app.get("/oauth/{provider}/login/callback")
+async def oauth_login_callback(provider: str, request: Request, response: Response):
     return await oauth_manager.handle_callback(request, provider, response)
 
 

+ 20 - 0
backend/open_webui/models/oauth_sessions.py

@@ -176,6 +176,26 @@ class OAuthSessionTable:
             log.error(f"Error getting OAuth session by ID: {e}")
             return None
 
+    def get_session_by_provider_and_user_id(
+        self, provider: str, user_id: str
+    ) -> Optional[OAuthSessionModel]:
+        """Get OAuth session by provider and user ID"""
+        try:
+            with get_db() as db:
+                session = (
+                    db.query(OAuthSession)
+                    .filter_by(provider=provider, user_id=user_id)
+                    .first()
+                )
+                if session:
+                    session.token = self._decrypt_token(session.token)
+                    return OAuthSessionModel.model_validate(session)
+
+                return None
+        except Exception as e:
+            log.error(f"Error getting OAuth session by provider and user ID: {e}")
+            return None
+
     def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]:
         """Get all OAuth sessions for a user"""
         try:

+ 27 - 3
backend/open_webui/routers/configs.py

@@ -21,7 +21,9 @@ from open_webui.env import SRC_LOG_LEVELS
 from open_webui.utils.oauth import (
     get_discovery_urls,
     get_oauth_client_info_with_dynamic_client_registration,
-    encrypt_token,
+    encrypt_data,
+    decrypt_data,
+    OAuthClientInformationFull,
 )
 from mcp.shared.auth import OAuthMetadata
 
@@ -103,17 +105,22 @@ class OAuthClientRegistrationForm(BaseModel):
 async def register_oauth_client(
     request: Request,
     form_data: OAuthClientRegistrationForm,
+    type: Optional[str] = None,
     user=Depends(get_admin_user),
 ):
     try:
+        oauth_client_id = form_data.client_id
+        if type:
+            oauth_client_id = f"{type}:{form_data.client_id}"
+
         oauth_client_info = (
             await get_oauth_client_info_with_dynamic_client_registration(
-                request, form_data.url
+                request, oauth_client_id, form_data.url
             )
         )
         return {
             "status": True,
-            "oauth_client_info": encrypt_token(
+            "oauth_client_info": encrypt_data(
                 oauth_client_info.model_dump(mode="json")
             ),
         }
@@ -161,8 +168,25 @@ async def set_tool_servers_config(
     request.app.state.config.TOOL_SERVER_CONNECTIONS = [
         connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
     ]
+
     await set_tool_servers(request)
 
+    for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
+        server_type = connection.get("type", "openapi")
+        if server_type == "mcp":
+            server_id = connection.get("info", {}).get("id")
+            auth_type = connection.get("auth_type", "none")
+            if auth_type == "oauth_2.1" and server_id:
+                try:
+                    oauth_client_info = decrypt_data(oauth_client_info)
+                    await request.app.state.oauth_client_manager.add_client(
+                        f"{server_type}:{server_id}",
+                        OAuthClientInformationFull(**oauth_client_info),
+                    )
+                except Exception as e:
+                    log.debug(f"Failed to add OAuth client for MCP tool server: {e}")
+                    continue
+
     return {
         "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
     }

+ 26 - 0
backend/open_webui/routers/tools.py

@@ -9,6 +9,7 @@ from pydantic import BaseModel, HttpUrl
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 
 
+from open_webui.models.oauth_sessions import OAuthSessions
 from open_webui.models.tools import (
     ToolForm,
     ToolModel,
@@ -80,6 +81,24 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
     # MCP Tool Servers
     for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
         if server.get("type", "openapi") == "mcp":
+            server_id = server.get("info", {}).get("id")
+            auth_type = server.get("auth_type", "none")
+
+            session_token = None
+            if auth_type == "oauth_2.1":
+                splits = server_id.split(":")
+                server_id = splits[-1] if len(splits) > 1 else server_id
+
+                session_token = (
+                    await request.app.state.oauth_client_manager.get_oauth_token(
+                        user.id, f"mcp:{server_id}"
+                    )
+                )
+
+                print("User ID:", user.id)
+                print("Server ID:", server_id)
+                print("MCP Session Token:", session_token)
+
             tools.append(
                 ToolUserResponse(
                     **{
@@ -96,6 +115,13 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
                         ),
                         "updated_at": int(time.time()),
                         "created_at": int(time.time()),
+                        **(
+                            {
+                                "authenticated": session_token is not None,
+                            }
+                            if auth_type == "oauth_2.1"
+                            else {}
+                        ),
                     }
                 )
             )

+ 17 - 0
backend/open_webui/utils/middleware.py

@@ -24,6 +24,7 @@ from fastapi.responses import HTMLResponse
 from starlette.responses import Response, StreamingResponse, JSONResponse
 
 
+from open_webui.models.oauth_sessions import OAuthSessions
 from open_webui.models.chats import Chats
 from open_webui.models.folders import Folders
 from open_webui.models.users import Users
@@ -1047,6 +1048,22 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                             headers["Authorization"] = (
                                 f"Bearer {oauth_token.get('access_token', '')}"
                             )
+                    elif auth_type == "oauth_2.1":
+                        try:
+                            splits = server_id.split(":")
+                            server_id = splits[-1] if len(splits) > 1 else server_id
+
+                            oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
+                                user.id, f"mcp:{server_id}"
+                            )
+
+                            if oauth_token:
+                                headers["Authorization"] = (
+                                    f"Bearer {oauth_token.get('access_token', '')}"
+                                )
+                        except Exception as e:
+                            log.error(f"Error getting OAuth token: {e}")
+                            oauth_token = None
 
                     mcp_client = MCPClient()
                     await mcp_client.connect(

+ 45 - 37
backend/open_webui/utils/oauth.py

@@ -126,24 +126,24 @@ except Exception as e:
     raise
 
 
-def encrypt_token(token) -> str:
-    """Encrypt OAuth tokens for storage"""
+def encrypt_data(data) -> str:
+    """Encrypt data for storage"""
     try:
-        token_json = json.dumps(token)
-        encrypted = FERNET.encrypt(token_json.encode()).decode()
+        data_json = json.dumps(data)
+        encrypted = FERNET.encrypt(data_json.encode()).decode()
         return encrypted
     except Exception as e:
-        log.error(f"Error encrypting tokens: {e}")
+        log.error(f"Error encrypting data: {e}")
         raise
 
 
-def decrypt_token(token: str):
-    """Decrypt OAuth tokens from storage"""
+def decrypt_data(data: str):
+    """Decrypt data from storage"""
     try:
-        decrypted = FERNET.decrypt(token.encode()).decode()
+        decrypted = FERNET.decrypt(data.encode()).decode()
         return json.loads(decrypted)
     except Exception as e:
-        log.error(f"Error decrypting tokens: {e}")
+        log.error(f"Error decrypting data: {e}")
         raise
 
 
@@ -212,7 +212,10 @@ def get_discovery_urls(server_url) -> list[str]:
 # TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration.
 # This is not currently supported.
 async def get_oauth_client_info_with_dynamic_client_registration(
-    request, oauth_server_url, oauth_server_key: Optional[str] = None
+    request,
+    client_id: str,
+    oauth_server_url: str,
+    oauth_server_key: Optional[str] = None,
 ) -> OAuthClientInformationFull:
     try:
         oauth_server_metadata = None
@@ -221,9 +224,10 @@ async def get_oauth_client_info_with_dynamic_client_registration(
         redirect_base_url = (
             str(request.app.state.config.WEBUI_URL or request.base_url)
         ).rstrip("/")
+
         oauth_client_metadata = OAuthClientMetadata(
             client_name="Open WebUI",
-            redirect_uris=[f"{redirect_base_url}/oauth/callback"],
+            redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
             grant_types=["authorization_code", "refresh_token"],
             response_types=["code"],
             token_endpoint_auth_method="client_secret_post",
@@ -315,23 +319,22 @@ class OAuthClientManager:
         self.clients = {}
 
     def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull):
-        if client_id not in self.clients:
-            self.clients[client_id] = {
-                "client": self.oauth.register(
-                    name=client_id,
-                    client_id=oauth_client_info.client_id,
-                    client_secret=oauth_client_info.client_secret,
-                    client_kwargs=(
-                        {"scope": oauth_client_info.scope}
-                        if oauth_client_info.scope
-                        else {}
-                    ),
-                    server_metadata_url=(
-                        oauth_client_info.issuer if oauth_client_info.issuer else None
-                    ),
+        self.clients[client_id] = {
+            "client": self.oauth.register(
+                name=client_id,
+                client_id=oauth_client_info.client_id,
+                client_secret=oauth_client_info.client_secret,
+                client_kwargs=(
+                    {"scope": oauth_client_info.scope}
+                    if oauth_client_info.scope
+                    else {}
                 ),
-                "client_info": oauth_client_info,
-            }
+                server_metadata_url=(
+                    oauth_client_info.issuer if oauth_client_info.issuer else None
+                ),
+            ),
+            "client_info": oauth_client_info,
+        }
         return self.clients[client_id]
 
     def remove_client(self, client_id):
@@ -359,7 +362,7 @@ class OAuthClientManager:
         return None
 
     async def get_oauth_token(
-        self, user_id: str, session_id: str, force_refresh: bool = False
+        self, user_id: str, client_id: str, force_refresh: bool = False
     ):
         """
         Get a valid OAuth token for the user, automatically refreshing if needed.
@@ -374,10 +377,12 @@ class OAuthClientManager:
         """
         try:
             # Get the OAuth session
-            session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id)
+            session = OAuthSessions.get_session_by_provider_and_user_id(
+                client_id, user_id
+            )
             if not session:
                 log.warning(
-                    f"No OAuth session found for user {user_id}, session {session_id}"
+                    f"No OAuth session found for user {user_id}, client_id {client_id}"
                 )
                 return None
 
@@ -392,8 +397,9 @@ class OAuthClientManager:
                     return refreshed_token
                 else:
                     log.warning(
-                        f"Token refresh failed for user {user_id}, client_id {session.provider}"
+                        f"Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}"
                     )
+                    OAuthSessions.delete_session_by_id(session.id)
                     return None
             return session.token
 
@@ -533,7 +539,7 @@ class OAuthClientManager:
         redirect_uri = (
             client_info.redirect_uris[0] if client_info.redirect_uris else None
         )
-        return await client.authorize_redirect(request, redirect_uri)
+        return await client.authorize_redirect(request, str(redirect_uri))
 
     async def handle_callback(self, request, client_id: str, user_id: str, response):
         client = self.get_client(client_id)
@@ -565,7 +571,6 @@ class OAuthClientManager:
                         provider=client_id,
                         token=token,
                     )
-
                     log.info(
                         f"Stored OAuth session server-side for user {user_id}, client_id {client_id}"
                     )
@@ -579,16 +584,17 @@ class OAuthClientManager:
             error_message = "OAuth callback error"
             log.warning(f"OAuth callback error: {e}")
 
-        redirect_base_url = (
+        redirect_url = (
             str(request.app.state.config.WEBUI_URL or request.base_url)
         ).rstrip("/")
-        redirect_url = f"{redirect_base_url}/auth"
 
         if error_message:
-            redirect_url = f"{redirect_url}?error={error_message}"
+            log.debug(error_message)
+            redirect_url = f"{redirect_url}/?error={error_message}"
             return RedirectResponse(url=redirect_url, headers=response.headers)
 
         response = RedirectResponse(url=redirect_url, headers=response.headers)
+        return response
 
 
 class OAuthManager:
@@ -649,8 +655,10 @@ class OAuthManager:
                     return refreshed_token
                 else:
                     log.warning(
-                        f"Token refresh failed for user {user_id}, provider {session.provider}"
+                        f"Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}"
                     )
+                    OAuthSessions.delete_session_by_id(session.id)
+
                     return None
             return session.token
 

+ 13 - 3
src/lib/apis/configs/index.ts

@@ -1,4 +1,4 @@
-import { WEBUI_API_BASE_URL } from '$lib/constants';
+import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
 import type { Banner } from '$lib/types';
 
 export const importConfig = async (token: string, config) => {
@@ -208,10 +208,15 @@ type RegisterOAuthClientForm = {
 	client_name?: string;
 };
 
-export const registerOAuthClient = async (token: string, formData: RegisterOAuthClientForm) => {
+export const registerOAuthClient = async (
+	token: string,
+	formData: RegisterOAuthClientForm,
+	type: null | string = null
+) => {
 	let error = null;
 
-	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/oauth/clients/register`, {
+	const searchParams = type ? `?type=${type}` : '';
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/oauth/clients/register${searchParams}`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',
@@ -238,6 +243,11 @@ export const registerOAuthClient = async (token: string, formData: RegisterOAuth
 	return res;
 };
 
+export const getOAuthClientAuthorizationUrl = (clientId: string, type: null | string = null) => {
+	const oauthClientId = type ? `${type}:${clientId}` : clientId;
+	return `${WEBUI_BASE_URL}/oauth/clients/${oauthClientId}/authorize`;
+};
+
 export const getCodeExecutionConfig = async (token: string) => {
 	let error = null;
 

+ 14 - 4
src/lib/components/AddToolServerModal.svelte

@@ -57,16 +57,26 @@
 			return;
 		}
 
-		const res = await registerOAuthClient(localStorage.token, {
-			url: url,
-			client_id: id
-		}).catch((err) => {
+		const res = await registerOAuthClient(
+			localStorage.token,
+			{
+				url: url,
+				client_id: id
+			},
+			'mcp'
+		).catch((err) => {
 			toast.error($i18n.t('Registration failed'));
 			return null;
 		});
 
 		if (res) {
+			toast.warning(
+				$i18n.t(
+					'Please save the connection to persist the OAuth client information and do not change the ID'
+				)
+			);
 			toast.success($i18n.t('Registration successful'));
+
 			console.debug('Registration successful', res);
 			oauthClientInfo = res?.oauth_client_info ?? null;
 		}

+ 19 - 3
src/lib/components/chat/MessageInput/IntegrationsMenu.svelte

@@ -20,6 +20,8 @@
 	import ChevronRight from '$lib/components/icons/ChevronRight.svelte';
 	import ChevronLeft from '$lib/components/icons/ChevronLeft.svelte';
 	import ValvesModal from '$lib/components/workspace/common/ValvesModal.svelte';
+	import { getOAuthClientAuthorizationUrl } from '$lib/apis/configs';
+	import { partition } from 'd3-hierarchy';
 
 	const i18n = getContext('i18n');
 
@@ -321,11 +323,25 @@
 
 					{#each Object.keys(tools) as toolId}
 						<button
-							class="flex w-full justify-between gap-2 items-center px-3 py-1.5 text-sm cursor-pointer rounded-xl hover:bg-gray-50 dark:hover:bg-gray-800/50"
-							on:click={() => {
-								tools[toolId].enabled = !tools[toolId].enabled;
+							class="relative flex w-full justify-between gap-2 items-center px-3 py-1.5 text-sm cursor-pointer rounded-xl hover:bg-gray-50 dark:hover:bg-gray-800/50"
+							on:click={(e) => {
+								if (!(tools[toolId]?.authenticated ?? true)) {
+									e.preventDefault();
+
+									let parts = toolId.split(':');
+									let serverId = parts?.at(-1) ?? toolId;
+
+									const authUrl = getOAuthClientAuthorizationUrl(serverId, 'mcp');
+									window.open(authUrl, '_blank', 'noopener');
+								} else {
+									tools[toolId].enabled = !tools[toolId].enabled;
+								}
 							}}
 						>
+							{#if !(tools[toolId]?.authenticated ?? true)}
+								<!-- make it slighly darker and not clickable -->
+								<div class="absolute inset-0 opacity-50 rounded-xl cursor-not-allowed z-10" />
+							{/if}
 							<div class="flex-1 truncate">
 								<div class="flex flex-1 gap-2 items-center">
 									<Tooltip content={tools[toolId]?.name ?? ''} placement="top">

+ 10 - 0
src/routes/(app)/+page.svelte

@@ -1,5 +1,15 @@
 <script lang="ts">
+	import { onMount } from 'svelte';
+	import { toast } from 'svelte-sonner';
+
 	import Chat from '$lib/components/chat/Chat.svelte';
+	import { page } from '$app/stores';
+
+	onMount(() => {
+		if ($page.url.searchParams.get('error')) {
+			toast.error($page.url.searchParams.get('error') || 'An unknown error occurred.');
+		}
+	});
 </script>
 
 <Chat />