# grpc_peer_handle.py

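"""gRPC-backed implementation of PeerHandle.

Wraps an async gRPC channel/stub pair so a local node can exchange prompts,
tensors, training examples, results, and topology information with a remote
peer over the NodeService RPC interface.
"""
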
import asyncio
import json
from typing import List, Optional, Tuple, Union

import grpc
import mlx.core as mx
import numpy as np

from . import node_service_pb2
from . import node_service_pb2_grpc
from ..peer_handle import PeerHandle
from exo.inference.shard import Shard
from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.helpers import DEBUG


class GRPCPeerHandle(PeerHandle):
    def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
        self._id = _id
        self.address = address
        self.desc = desc
        self._device_capabilities = device_capabilities
        self.channel = None
        self.stub = None

    def id(self) -> str:
        return self._id

    def addr(self) -> str:
        return self.address

    def description(self) -> str:
        return self.desc

    def device_capabilities(self) -> DeviceCapabilities:
        return self._device_capabilities

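    # The channel is created lazily on first connect and reused afterwards.
    # Message and metadata size limits are raised to 32 MiB so that large
    # serialized tensors fit in a single gRPC message.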
    async def connect(self):
        if self.channel is None:
            self.channel = grpc.aio.insecure_channel(
                self.address,
                options=[
                    ("grpc.max_metadata_size", 32 * 1024 * 1024),
                    ("grpc.max_receive_message_length", 32 * 1024 * 1024),
                    ("grpc.max_send_message_length", 32 * 1024 * 1024),
                ],
            )
            self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
        await self.channel.channel_ready()

    async def is_connected(self) -> bool:
        return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY

    async def disconnect(self):
        if self.channel:
            await self.channel.close()
        self.channel = None
        self.stub = None

    async def _ensure_connected(self):
        if not await self.is_connected():
            await asyncio.wait_for(self.connect(), timeout=5)

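    # health_check is the only RPC here that (re)connects on demand and
    # enforces its own timeout; the other RPCs assume connect() has already
    # been awaited, since self.stub is None until then.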
    async def health_check(self) -> bool:
        try:
            await self._ensure_connected()
            request = node_service_pb2.HealthCheckRequest()
            response = await asyncio.wait_for(self.stub.HealthCheck(request), timeout=5)
            return response.is_healthy
        except asyncio.TimeoutError:
            return False
        except Exception:
            if DEBUG >= 4:
                print(f"Health check failed for {self._id}@{self.address}.")
                import traceback
                traceback.print_exc()
            return False

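    # The send_* RPCs below share one tensor wire format: raw bytes from
    # ndarray.tobytes() plus the shape and the dtype string, reconstructed
    # with np.frombuffer(...).reshape(...). An empty tensor in the response
    # means the peer had no output, so these methods return None in that case.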
    async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.ndarray]:
        request = node_service_pb2.PromptRequest(
            prompt=prompt,
            shard=node_service_pb2.Shard(
                model_id=shard.model_id,
                start_layer=shard.start_layer,
                end_layer=shard.end_layer,
                n_layers=shard.n_layers,
            ),
            request_id=request_id,
            inference_state=None if inference_state is None else self.serialize_inference_state(inference_state),
        )
        response = await self.stub.SendPrompt(request)
        if not response.tensor_data or not response.shape or not response.dtype:
            return None
        return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

    async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.ndarray]:
        request = node_service_pb2.TensorRequest(
            shard=node_service_pb2.Shard(
                model_id=shard.model_id,
                start_layer=shard.start_layer,
                end_layer=shard.end_layer,
                n_layers=shard.n_layers,
            ),
            tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
            request_id=request_id,
            inference_state=None if inference_state is None else self.serialize_inference_state(inference_state),
        )
        response = await self.stub.SendTensor(request)
        if not response.tensor_data or not response.shape or not response.dtype:
            return None
        return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

    async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Union[float, Tuple[float, np.ndarray]]:
        request = node_service_pb2.ExampleRequest(
            shard=node_service_pb2.Shard(
                model_id=shard.model_id,
                start_layer=shard.start_layer,
                end_layer=shard.end_layer,
                n_layers=shard.n_layers,
            ),
            example=node_service_pb2.Tensor(tensor_data=example.tobytes(), shape=example.shape, dtype=str(example.dtype)),
            target=node_service_pb2.Tensor(tensor_data=target.tobytes(), shape=target.shape, dtype=str(target.dtype)),
            length=node_service_pb2.Tensor(tensor_data=length.tobytes(), shape=length.shape, dtype=str(length.dtype)),
            train=train,
            request_id=request_id,
        )
        response = await self.stub.SendExample(request)
        loss = response.loss
        # When training a non-first shard, the gradients to propagate backward
        # come back alongside the loss.
        if train and not shard.is_first_layer():
            grads = np.frombuffer(response.grads.tensor_data, dtype=np.dtype(response.grads.dtype)).reshape(response.grads.shape)
            return loss, grads
        return loss

    async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
        request = node_service_pb2.TensorRequest(
            shard=node_service_pb2.Shard(
                model_id=shard.model_id,
                start_layer=shard.start_layer,
                end_layer=shard.end_layer,
                n_layers=shard.n_layers,
            ),
            tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
            request_id=request_id,
        )
        response = await self.stub.SendLoss(request)
        if not response.tensor_data or not response.shape or not response.dtype:
            return None
        return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
        request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
        response = await self.stub.GetInferenceResult(request)
        # Proto3 message fields are never None; check field presence instead.
        if not response.HasField("tensor"):
            return None, response.is_finished
        return (
            np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
            response.is_finished,
        )

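    # Topology discovery: the peer reports the nodes and peer connections it
    # knows about, which are folded into a local Topology object.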
    async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
        request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
        response = await self.stub.CollectTopology(request)
        topology = Topology()
        for node_id, capabilities in response.nodes.items():
            device_capabilities = DeviceCapabilities(
                model=capabilities.model,
                chip=capabilities.chip,
                memory=capabilities.memory,
                flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8),
            )
            topology.update_node(node_id, device_capabilities)
        for node_id, peer_connections in response.peer_graph.items():
            for conn in peer_connections.connections:
                topology.add_edge(node_id, conn.to_id, conn.description)
        return topology

    async def send_result(self, request_id: str, result: Union[List[int], np.ndarray], is_finished: bool) -> None:
        tensor = None
        # An ndarray result is sent via the tensor field; token lists go in
        # the repeated result field.
        if isinstance(result, np.ndarray):
            tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
            result = []
        request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished)
        await self.stub.SendResult(request)

    async def send_opaque_status(self, request_id: str, status: str) -> None:
        request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
        await self.stub.SendOpaqueStatus(request)

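    # mx.array values (and lists of them) in the inference state are shipped
    # as protobuf Tensor/TensorList entries; any remaining values are bundled
    # into a single JSON string.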
    def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
        proto_inference_state = node_service_pb2.InferenceState()
        other_data = {}
        for k, v in inference_state.items():
            if isinstance(v, mx.array):
                np_array = np.array(v)
                tensor_data = node_service_pb2.Tensor(
                    tensor_data=np_array.tobytes(),
                    shape=list(np_array.shape),
                    dtype=str(np_array.dtype),
                )
                proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
            elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
                tensor_list = node_service_pb2.TensorList()
                for tensor in v:
                    np_array = np.array(tensor)
                    tensor_data = node_service_pb2.Tensor(
                        tensor_data=np_array.tobytes(),
                        shape=list(np_array.shape),
                        dtype=str(np_array.dtype),
                    )
                    tensor_list.tensors.append(tensor_data)
                proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
            else:
                # For non-tensor data, we'll still use JSON
                other_data[k] = v
        if other_data:
            proto_inference_state.other_data_json = json.dumps(other_data)
        return proto_inference_state
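

# A minimal usage sketch (assumptions: a NodeService server is reachable at
# the given address, and the capability/shard values below are placeholders):
#
#   async def demo():
#       caps = DeviceCapabilities(
#           model="example-device", chip="example-chip", memory=16384,
#           flops=DeviceFlops(fp16=1.0, fp32=1.0, int8=1.0),
#       )
#       peer = GRPCPeerHandle("node-1", "192.168.0.2:50051", "example peer", caps)
#       if await peer.health_check():  # connects on demand
#           shard = Shard(model_id="example-model", start_layer=0, end_layer=0, n_layers=32)
#           result = await peer.send_prompt(shard, "hello", request_id="req-1")
#       await peer.disconnect()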