|
|
@@ -1,4 +1,5 @@
|
|
|
import logging
|
|
|
+import copy
|
|
|
from fastapi import APIRouter, Depends, Request, HTTPException
|
|
|
from pydantic import BaseModel, ConfigDict
|
|
|
import aiohttp
|
|
|
@@ -15,6 +16,7 @@ from open_webui.utils.tools import (
|
|
|
set_tool_servers,
|
|
|
)
|
|
|
from open_webui.utils.mcp.client import MCPClient
|
|
|
+from open_webui.models.oauth_sessions import OAuthSessions
|
|
|
|
|
|
from open_webui.env import SRC_LOG_LEVELS
|
|
|
|
|
|
@@ -165,12 +167,59 @@ async def set_tool_servers_config(
|
|
|
form_data: ToolServersConfigForm,
|
|
|
user=Depends(get_admin_user),
|
|
|
):
|
|
|
- request.app.state.config.TOOL_SERVER_CONNECTIONS = [
|
|
|
+ old_connections = copy.deepcopy(
|
|
|
+ request.app.state.config.TOOL_SERVER_CONNECTIONS or []
|
|
|
+ )
|
|
|
+
|
|
|
+ new_connections = [
|
|
|
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
|
|
]
|
|
|
|
|
|
+ old_mcp_connections = {
|
|
|
+ conn.get("info", {}).get("id"): conn
|
|
|
+ for conn in old_connections
|
|
|
+ if conn.get("type") == "mcp"
|
|
|
+ }
|
|
|
+ new_mcp_connections = {
|
|
|
+ conn.get("info", {}).get("id"): conn
|
|
|
+ for conn in new_connections
|
|
|
+ if conn.get("type") == "mcp"
|
|
|
+ }
|
|
|
+
|
|
|
+ purge_oauth_clients = set()
|
|
|
+
|
|
|
+ for server_id, old_conn in old_mcp_connections.items():
|
|
|
+ if not server_id:
|
|
|
+ continue
|
|
|
+
|
|
|
+ old_auth_type = old_conn.get("auth_type", "none")
|
|
|
+ new_conn = new_mcp_connections.get(server_id)
|
|
|
+
|
|
|
+ if new_conn is None:
|
|
|
+ if old_auth_type == "oauth_2.1":
|
|
|
+ purge_oauth_clients.add(server_id)
|
|
|
+ continue
|
|
|
+
|
|
|
+ new_auth_type = new_conn.get("auth_type", "none")
|
|
|
+
|
|
|
+ if old_auth_type == "oauth_2.1":
|
|
|
+ if (
|
|
|
+ new_auth_type != "oauth_2.1"
|
|
|
+ or old_conn.get("url") != new_conn.get("url")
|
|
|
+ or old_conn.get("info", {}).get("oauth_client_info")
|
|
|
+ != new_conn.get("info", {}).get("oauth_client_info")
|
|
|
+ ):
|
|
|
+ purge_oauth_clients.add(server_id)
|
|
|
+
|
|
|
+ request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections
|
|
|
+
|
|
|
await set_tool_servers(request)
|
|
|
|
|
|
+ for server_id in purge_oauth_clients:
|
|
|
+ client_key = f"mcp:{server_id}"
|
|
|
+ request.app.state.oauth_client_manager.remove_client(client_key)
|
|
|
+ OAuthSessions.delete_sessions_by_provider(client_key)
|
|
|
+
|
|
|
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
|
|
server_type = connection.get("type", "openapi")
|
|
|
if server_type == "mcp":
|