|
@@ -16,19 +16,25 @@ class MCPClient:
|
|
|
async def connect(
|
|
|
self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None
|
|
|
):
|
|
|
- self._streams_context = streamablehttp_client(url, headers=headers, auth=auth)
|
|
|
- read_stream, write_stream, _ = (
|
|
|
- await self._streams_context.__aenter__()
|
|
|
- ) # pylint: disable=E1101
|
|
|
+ try:
|
|
|
+ self._streams_context = streamablehttp_client(
|
|
|
+ url, headers=headers, auth=auth
|
|
|
+ )
|
|
|
+
|
|
|
+ transport = await self.exit_stack.enter_async_context(self._streams_context)
|
|
|
+ read_stream, write_stream, _ = transport
|
|
|
|
|
|
- self._session_context = ClientSession(
|
|
|
- read_stream, write_stream
|
|
|
- ) # pylint: disable=W0201
|
|
|
- self.session: ClientSession = (
|
|
|
- await self._session_context.__aenter__()
|
|
|
- ) # pylint: disable=C2801
|
|
|
+ self._session_context = ClientSession(
|
|
|
+ read_stream, write_stream
|
|
|
+ ) # pylint: disable=W0201
|
|
|
|
|
|
- await self.session.initialize()
|
|
|
+ self.session = await self.exit_stack.enter_async_context(
|
|
|
+ self._session_context
|
|
|
+ )
|
|
|
+ await self.session.initialize()
|
|
|
+ except Exception as e:
|
|
|
+ await self.disconnect()
|
|
|
+ raise e
|
|
|
|
|
|
async def list_tool_specs(self) -> Optional[dict]:
|
|
|
if not self.session:
|
|
@@ -97,15 +103,7 @@ class MCPClient:
|
|
|
|
|
|
async def disconnect(self):
|
|
|
# Clean up and close the session
|
|
|
- if self.session:
|
|
|
- await self._session_context.__aexit__(
|
|
|
- None, None, None
|
|
|
- ) # pylint: disable=E1101
|
|
|
- if self._streams_context:
|
|
|
- await self._streams_context.__aexit__(
|
|
|
- None, None, None
|
|
|
- ) # pylint: disable=E1101
|
|
|
- self.session = None
|
|
|
+ await self.exit_stack.aclose()
|
|
|
|
|
|
async def __aenter__(self):
|
|
|
await self.exit_stack.__aenter__()
|