client.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import asyncio
  2. from typing import Optional
  3. from contextlib import AsyncExitStack
  4. import anyio
  5. from mcp import ClientSession
  6. from mcp.client.auth import OAuthClientProvider, TokenStorage
  7. from mcp.client.streamable_http import streamablehttp_client
  8. from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
  9. class MCPClient:
  10. def __init__(self):
  11. self.session: Optional[ClientSession] = None
  12. self.exit_stack = None
  13. async def connect(self, url: str, headers: Optional[dict] = None):
  14. async with AsyncExitStack() as exit_stack:
  15. try:
  16. self._streams_context = streamablehttp_client(url, headers=headers)
  17. transport = await exit_stack.enter_async_context(self._streams_context)
  18. read_stream, write_stream, _ = transport
  19. self._session_context = ClientSession(
  20. read_stream, write_stream
  21. ) # pylint: disable=W0201
  22. self.session = await exit_stack.enter_async_context(
  23. self._session_context
  24. )
  25. with anyio.fail_after(10):
  26. await self.session.initialize()
  27. self.exit_stack = exit_stack.pop_all()
  28. except Exception as e:
  29. await asyncio.shield(self.disconnect())
  30. raise e
  31. async def list_tool_specs(self) -> Optional[dict]:
  32. if not self.session:
  33. raise RuntimeError("MCP client is not connected.")
  34. result = await self.session.list_tools()
  35. tools = result.tools
  36. tool_specs = []
  37. for tool in tools:
  38. name = tool.name
  39. description = tool.description
  40. inputSchema = tool.inputSchema
  41. # TODO: handle outputSchema if needed
  42. outputSchema = getattr(tool, "outputSchema", None)
  43. tool_specs.append(
  44. {"name": name, "description": description, "parameters": inputSchema}
  45. )
  46. return tool_specs
  47. async def call_tool(
  48. self, function_name: str, function_args: dict
  49. ) -> Optional[dict]:
  50. if not self.session:
  51. raise RuntimeError("MCP client is not connected.")
  52. result = await self.session.call_tool(function_name, function_args)
  53. if not result:
  54. raise Exception("No result returned from MCP tool call.")
  55. result_dict = result.model_dump(mode="json")
  56. result_content = result_dict.get("content", {})
  57. if result.isError:
  58. raise Exception(result_content)
  59. else:
  60. return result_content
  61. async def list_resources(self, cursor: Optional[str] = None) -> Optional[dict]:
  62. if not self.session:
  63. raise RuntimeError("MCP client is not connected.")
  64. result = await self.session.list_resources(cursor=cursor)
  65. if not result:
  66. raise Exception("No result returned from MCP list_resources call.")
  67. result_dict = result.model_dump()
  68. resources = result_dict.get("resources", [])
  69. return resources
  70. async def read_resource(self, uri: str) -> Optional[dict]:
  71. if not self.session:
  72. raise RuntimeError("MCP client is not connected.")
  73. result = await self.session.read_resource(uri)
  74. if not result:
  75. raise Exception("No result returned from MCP read_resource call.")
  76. result_dict = result.model_dump()
  77. return result_dict
  78. async def disconnect(self):
  79. # Clean up and close the session
  80. await self.exit_stack.aclose()
  81. async def __aenter__(self):
  82. await self.exit_stack.__aenter__()
  83. return self
  84. async def __aexit__(self, exc_type, exc_value, traceback):
  85. await self.exit_stack.__aexit__(exc_type, exc_value, traceback)
  86. await self.disconnect()