grpc_peer_handle.py

import grpc
import numpy as np
import asyncio
from typing import Optional, Tuple, List

from . import node_service_pb2
from . import node_service_pb2_grpc

from ..peer_handle import PeerHandle
from exo.inference.shard import Shard
from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.helpers import DEBUG

class GRPCPeerHandle(PeerHandle):
  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
    self._id = _id
    self.address = address
    self._device_capabilities = device_capabilities
    self.channel = None
    self.stub = None

  def id(self) -> str:
    return self._id

  def addr(self) -> str:
    return self.address

  def device_capabilities(self) -> DeviceCapabilities:
    return self._device_capabilities

  async def connect(self):
    if self.channel is None:
      # Raise the message-size limits well above gRPC's 4 MB default, since
      # serialized inference tensors can be large.
      self.channel = grpc.aio.insecure_channel(
        self.address,
        options=[
          ("grpc.max_metadata_size", 32*1024*1024),
          ("grpc.max_receive_message_length", 32*1024*1024),
          ("grpc.max_send_message_length", 32*1024*1024),
        ],
      )
      self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
    await self.channel.channel_ready()

  async def is_connected(self) -> bool:
    return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY

  async def disconnect(self):
    if self.channel:
      await self.channel.close()
    self.channel = None
    self.stub = None

  async def _ensure_connected(self):
    if not await self.is_connected():
      await asyncio.wait_for(self.connect(), timeout=5)

  async def health_check(self) -> bool:
    try:
      await self._ensure_connected()
      request = node_service_pb2.HealthCheckRequest()
      response = await asyncio.wait_for(self.stub.HealthCheck(request), timeout=5)
      return response.is_healthy
    except asyncio.TimeoutError:
      return False
    except Exception:
      # Any failure to reach the peer is reported as unhealthy rather than raised.
      if DEBUG >= 4:
        print(f"Health check failed for {self._id}@{self.address}.")
        import traceback
        traceback.print_exc()
      return False

  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
    request = node_service_pb2.PromptRequest(
      prompt=prompt,
      shard=node_service_pb2.Shard(
        model_id=shard.model_id,
        start_layer=shard.start_layer,
        end_layer=shard.end_layer,
        n_layers=shard.n_layers,
      ),
      request_id=request_id,
    )
    response = await self.stub.SendPrompt(request)
    if not response.tensor_data or not response.shape or not response.dtype:
      return None
    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
    request = node_service_pb2.TensorRequest(
      shard=node_service_pb2.Shard(
        model_id=shard.model_id,
        start_layer=shard.start_layer,
        end_layer=shard.end_layer,
        n_layers=shard.n_layers,
      ),
      # Tensors travel as raw bytes plus the shape and dtype needed to rebuild them.
      tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
      request_id=request_id,
    )
    response = await self.stub.SendTensor(request)
    if not response.tensor_data or not response.shape or not response.dtype:
      return None
    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
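
  # Note: the Tensor round-trip above relies on numpy's raw-buffer serialization:
  # an array is shipped as (tobytes(), shape, str(dtype)) and rebuilt with
  # np.frombuffer(...).reshape(...). A minimal sketch of that invariant,
  # illustrative only and not part of the service API:
  #
  #   t = np.arange(6, dtype=np.float32).reshape(2, 3)
  #   msg = node_service_pb2.Tensor(tensor_data=t.tobytes(), shape=t.shape, dtype=str(t.dtype))
  #   restored = np.frombuffer(msg.tensor_data, dtype=np.dtype(msg.dtype)).reshape(tuple(msg.shape))
  #   assert np.array_equal(t, restored)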

  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
    request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
    response = await self.stub.GetInferenceResult(request)
    # Proto3 message fields are never None on a response, so check for an
    # unset tensor explicitly instead of comparing against None.
    if not response.HasField("tensor"):
      return None, response.is_finished
    return (
      np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
      response.is_finished,
    )

  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
    request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
    response = await self.stub.CollectTopology(request)
    topology = Topology()
    for node_id, capabilities in response.nodes.items():
      device_capabilities = DeviceCapabilities(
        model=capabilities.model,
        chip=capabilities.chip,
        memory=capabilities.memory,
        flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8),
      )
      topology.update_node(node_id, device_capabilities)
    for node_id, peers in response.peer_graph.items():
      for peer_id in peers.peer_ids:
        topology.add_edge(node_id, peer_id)
    return topology

  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
    request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
    await self.stub.SendResult(request)

  async def send_opaque_status(self, request_id: str, status: str) -> None:
    request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
    await self.stub.SendOpaqueStatus(request)
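
# A minimal usage sketch (hypothetical peer id and address; assumes a NodeService
# server is reachable there and that DeviceCapabilities/DeviceFlops accept the
# same fields this module reads from them above):
#
#   async def main():
#     caps = DeviceCapabilities(model="example", chip="example", memory=0,
#                               flops=DeviceFlops(fp16=0.0, fp32=0.0, int8=0.0))
#     peer = GRPCPeerHandle("node-1", "localhost:50051", caps)
#     await peer.connect()
#     print(await peer.health_check())
#     await peer.disconnect()
#
#   asyncio.run(main())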