Timothy Jaeryang Baek 2 тижнів тому
батько
коміт
61f20acf61

+ 30 - 25
backend/open_webui/routers/configs.py

@@ -133,39 +133,44 @@ async def verify_tool_servers_config(
     try:
         if form_data.type == "mcp":
             try:
-                async with MCPClient() as client:
-                    auth = None
-                    headers = None
-
-                    token = None
-                    if form_data.auth_type == "bearer":
-                        token = form_data.key
-                    elif form_data.auth_type == "session":
-                        token = request.state.token.credentials
-                    elif form_data.auth_type == "system_oauth":
-                        try:
-                            if request.cookies.get("oauth_session_id", None):
-                                token = await request.app.state.oauth_manager.get_oauth_token(
+                client = MCPClient()
+                auth = None
+                headers = None
+
+                token = None
+                if form_data.auth_type == "bearer":
+                    token = form_data.key
+                elif form_data.auth_type == "session":
+                    token = request.state.token.credentials
+                elif form_data.auth_type == "system_oauth":
+                    try:
+                        if request.cookies.get("oauth_session_id", None):
+                            token = (
+                                await request.app.state.oauth_manager.get_oauth_token(
                                     user.id,
                                     request.cookies.get("oauth_session_id", None),
                                 )
-                        except Exception as e:
-                            pass
-
-                    if token:
-                        headers = {"Authorization": f"Bearer {token}"}
-
-                    await client.connect(form_data.url, auth=auth, headers=headers)
-                    specs = await client.list_tool_specs()
-                    return {
-                        "status": True,
-                        "specs": specs,
-                    }
+                            )
+                    except Exception as e:
+                        pass
+
+                if token:
+                    headers = {"Authorization": f"Bearer {token}"}
+
+                await client.connect(form_data.url, auth=auth, headers=headers)
+                specs = await client.list_tool_specs()
+                return {
+                    "status": True,
+                    "specs": specs,
+                }
             except Exception as e:
                 raise HTTPException(
                     status_code=400,
                     detail=f"Failed to create MCP client: {str(e)}",
                 )
+            finally:
+                if client:
+                    await client.disconnect()
         else:  # openapi
             token = None
             if form_data.auth_type == "bearer":

+ 18 - 20
backend/open_webui/utils/mcp/client.py

@@ -16,19 +16,25 @@ class MCPClient:
     async def connect(
         self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None
     ):
-        self._streams_context = streamablehttp_client(url, headers=headers, auth=auth)
-        read_stream, write_stream, _ = (
-            await self._streams_context.__aenter__()
-        )  # pylint: disable=E1101
+        try:
+            self._streams_context = streamablehttp_client(
+                url, headers=headers, auth=auth
+            )
+
+            transport = await self.exit_stack.enter_async_context(self._streams_context)
+            read_stream, write_stream, _ = transport
 
-        self._session_context = ClientSession(
-            read_stream, write_stream
-        )  # pylint: disable=W0201
-        self.session: ClientSession = (
-            await self._session_context.__aenter__()
-        )  # pylint: disable=C2801
+            self._session_context = ClientSession(
+                read_stream, write_stream
+            )  # pylint: disable=W0201
 
-        await self.session.initialize()
+            self.session = await self.exit_stack.enter_async_context(
+                self._session_context
+            )
+            await self.session.initialize()
+        except Exception as e:
+            await self.disconnect()
+            raise e
 
     async def list_tool_specs(self) -> Optional[dict]:
         if not self.session:
@@ -97,15 +103,7 @@ class MCPClient:
 
     async def disconnect(self):
         # Clean up and close the session
-        if self.session:
-            await self._session_context.__aexit__(
-                None, None, None
-            )  # pylint: disable=E1101
-        if self._streams_context:
-            await self._streams_context.__aexit__(
-                None, None, None
-            )  # pylint: disable=E1101
-        self.session = None
+        await self.exit_stack.aclose()
 
     async def __aenter__(self):
         await self.exit_stack.__aenter__()