Просмотр исходного кода

Added a targeted utility to wipe all OAuth sessions for a provider so the cleanup can remove stale access tokens across every user when a connection is updated

Taylor Wilsdon 4 месяцев назад
Родитель
Сommit
c107a3799f
2 измененных файлов с 61 добавлено и 1 удалено
  1. 11 0
      backend/open_webui/models/oauth_sessions.py
  2. 50 1
      backend/open_webui/routers/configs.py

+ 11 - 0
backend/open_webui/models/oauth_sessions.py

@@ -262,5 +262,16 @@ class OAuthSessionTable:
             log.error(f"Error deleting OAuth sessions by user ID: {e}")
             return False
 
+    def delete_sessions_by_provider(self, provider: str) -> bool:
+        """Delete all OAuth sessions for a provider"""
+        try:
+            with get_db() as db:
+                db.query(OAuthSession).filter_by(provider=provider).delete()
+                db.commit()
+                return True
+        except Exception as e:
+            log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
+            return False
+
 
 OAuthSessions = OAuthSessionTable()

+ 50 - 1
backend/open_webui/routers/configs.py

@@ -1,4 +1,5 @@
 import logging
+import copy
 from fastapi import APIRouter, Depends, Request, HTTPException
 from pydantic import BaseModel, ConfigDict
 import aiohttp
@@ -15,6 +16,7 @@ from open_webui.utils.tools import (
     set_tool_servers,
 )
 from open_webui.utils.mcp.client import MCPClient
+from open_webui.models.oauth_sessions import OAuthSessions
 
 from open_webui.env import SRC_LOG_LEVELS
 
@@ -165,12 +167,59 @@ async def set_tool_servers_config(
     form_data: ToolServersConfigForm,
     user=Depends(get_admin_user),
 ):
-    request.app.state.config.TOOL_SERVER_CONNECTIONS = [
+    old_connections = copy.deepcopy(
+        request.app.state.config.TOOL_SERVER_CONNECTIONS or []
+    )
+
+    new_connections = [
         connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
     ]
 
+    old_mcp_connections = {
+        conn.get("info", {}).get("id"): conn
+        for conn in old_connections
+        if conn.get("type") == "mcp"
+    }
+    new_mcp_connections = {
+        conn.get("info", {}).get("id"): conn
+        for conn in new_connections
+        if conn.get("type") == "mcp"
+    }
+
+    purge_oauth_clients = set()
+
+    for server_id, old_conn in old_mcp_connections.items():
+        if not server_id:
+            continue
+
+        old_auth_type = old_conn.get("auth_type", "none")
+        new_conn = new_mcp_connections.get(server_id)
+
+        if new_conn is None:
+            if old_auth_type == "oauth_2.1":
+                purge_oauth_clients.add(server_id)
+            continue
+
+        new_auth_type = new_conn.get("auth_type", "none")
+
+        if old_auth_type == "oauth_2.1":
+            if (
+                new_auth_type != "oauth_2.1"
+                or old_conn.get("url") != new_conn.get("url")
+                or old_conn.get("info", {}).get("oauth_client_info")
+                != new_conn.get("info", {}).get("oauth_client_info")
+            ):
+                purge_oauth_clients.add(server_id)
+
+    request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections
+
     await set_tool_servers(request)
 
+    for server_id in purge_oauth_clients:
+        client_key = f"mcp:{server_id}"
+        request.app.state.oauth_client_manager.remove_client(client_key)
+        OAuthSessions.delete_sessions_by_provider(client_key)
+
     for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
         server_type = connection.get("type", "openapi")
         if server_type == "mcp":