grpc_peer_handle.py

import grpc
import numpy as np
from typing import Optional, Tuple, List

# These would be generated from the .proto file
from . import node_service_pb2
from . import node_service_pb2_grpc

from ..peer_handle import PeerHandle
from exo.inference.shard import Shard
from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities


class GRPCPeerHandle(PeerHandle):
  """PeerHandle implementation that talks to a remote node over gRPC."""

  def __init__(self, id: str, address: str, device_capabilities: DeviceCapabilities):
    self._id = id
    self.address = address
    self._device_capabilities = device_capabilities
    self.channel = None
    self.stub = None

  def id(self) -> str:
    return self._id

  def device_capabilities(self) -> DeviceCapabilities:
    return self._device_capabilities

  async def connect(self):
    # Raise the metadata size limit to 32 MiB so large metadata is accepted.
    self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32 * 1024 * 1024)])
    self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)

  async def is_connected(self) -> bool:
    return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY

  async def disconnect(self):
    if self.channel:
      await self.channel.close()
    self.channel = None
    self.stub = None

  async def send_prompt(
    self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None
  ) -> Optional[np.ndarray]:
    request = node_service_pb2.PromptRequest(
      prompt=prompt,
      shard=node_service_pb2.Shard(
        model_id=shard.model_id,
        start_layer=shard.start_layer,
        end_layer=shard.end_layer,
        n_layers=shard.n_layers,
      ),
      request_id=request_id,
      inference_state=inference_state,
    )
    response = await self.stub.SendPrompt(request)
    # An empty tensor payload means the peer has no result for this prompt.
    if not response.tensor_data or not response.shape or not response.dtype:
      return None
    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

  async def send_tensor(
    self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None
  ) -> Optional[np.ndarray]:
    request = node_service_pb2.TensorRequest(
      shard=node_service_pb2.Shard(
        model_id=shard.model_id,
        start_layer=shard.start_layer,
        end_layer=shard.end_layer,
        n_layers=shard.n_layers,
      ),
      # Ship the ndarray as raw bytes plus its dtype and shape so the peer
      # can reconstruct it with np.frombuffer.
      tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
      request_id=request_id,
      inference_state=inference_state,
    )
    response = await self.stub.SendTensor(request)
    if not response.tensor_data or not response.shape or not response.dtype:
      return None
    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
    request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
    response = await self.stub.GetInferenceResult(request)
    # Proto3 message fields are never None; check field presence instead.
    if not response.HasField("tensor"):
      return None, response.is_finished
    return (
      np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
      response.is_finished,
    )

  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
    request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
    response = await self.stub.CollectTopology(request)
    topology = Topology()
    for node_id, capabilities in response.nodes.items():
      device_capabilities = DeviceCapabilities(
        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops
      )
      topology.update_node(node_id, device_capabilities)
    for node_id, peers in response.peer_graph.items():
      for peer_id in peers.peer_ids:
        topology.add_edge(node_id, peer_id)
    return topology

  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
    request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
    await self.stub.SendResult(request)

  async def send_opaque_status(self, request_id: str, status: str) -> None:
    request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
    await self.stub.SendOpaqueStatus(request)
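

# The tensor-bearing RPCs above ship ndarrays as raw bytes plus a dtype string
# and a shape. A minimal sketch of that round trip, using plain numpy with no
# gRPC or proto types involved (the function name is ours, not part of the
# service API):
def _tensor_roundtrip_sketch() -> None:
  original = np.arange(6, dtype=np.float32).reshape(2, 3)
  # Serialize the way send_tensor does: raw bytes + dtype string + shape.
  data, dtype, shape = original.tobytes(), str(original.dtype), original.shape
  # Deserialize the way the response handlers do.
  restored = np.frombuffer(data, dtype=np.dtype(dtype)).reshape(shape)
  assert np.array_equal(original, restored)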
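

# A usage sketch, assuming a NodeService server is listening at the address
# below. The address, model id, and capability values are hypothetical
# placeholders, and the exact DeviceCapabilities field types may differ:
if __name__ == "__main__":
  import asyncio

  async def _demo() -> None:
    caps = DeviceCapabilities(model="example", chip="example", memory=8192, flops=None)  # illustrative values
    peer = GRPCPeerHandle(id="node-1", address="localhost:50051", device_capabilities=caps)
    await peer.connect()
    shard = Shard(model_id="example-model", start_layer=0, end_layer=0, n_layers=1)
    result = await peer.send_prompt(shard, "Hello", request_id="req-1")
    print(result)
    await peer.disconnect()

  asyncio.run(_demo())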