grpc_server.py

import grpc
from concurrent import futures
import numpy as np
from asyncio import CancelledError
import platform
from . import node_service_pb2
from . import node_service_pb2_grpc
from exo import DEBUG
from exo.inference.shard import Shard
from exo.orchestration import Node
import json

if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
  import mlx.core as mx
else:
  import numpy as mx
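
# Servicer for the exo NodeService: lets peer nodes send prompts, tensors,
# training examples, results and status updates to the local Node, and query
# its view of the topology.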
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
  def __init__(self, node: Node, host: str, port: int):
    self.node = node
    self.host = host
    self.port = port
    self.server = None
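
  # Start the async gRPC server with large message-size limits and keepalive
  # options so big tensors can be exchanged between nodes.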
  async def start(self) -> None:
    self.server = grpc.aio.server(
      futures.ThreadPoolExecutor(max_workers=32),
      options=[
        ("grpc.max_metadata_size", 32*1024*1024),
        ("grpc.max_send_message_length", 256*1024*1024),
        ("grpc.max_receive_message_length", 256*1024*1024),
        ("grpc.keepalive_time_ms", 10000),
        ("grpc.keepalive_timeout_ms", 5000),
        ("grpc.http2.max_pings_without_data", 0),
        ("grpc.http2.min_time_between_pings_ms", 10000),
        ("grpc.http2.min_ping_interval_without_data_ms", 5000),
        ("grpc.max_concurrent_streams", 100),
        ("grpc.tcp_nodelay", 1),
        ("grpc.optimization_target", "throughput"),
      ],
    )
    node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
    listen_addr = f"{self.host}:{self.port}"
    self.server.add_insecure_port(listen_addr)
    await self.server.start()
    if DEBUG >= 1: print(f"Server started, listening on {listen_addr}")
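
  # Stop the server, allowing in-flight RPCs a short grace period to finish.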
  async def stop(self) -> None:
    if self.server:
      try:
        await self.server.stop(grace=5)
        await self.server.wait_for_termination()
      except CancelledError:
        pass
      if DEBUG >= 1: print("Server stopped and all connections are closed")
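
  # Process a prompt on this node's shard and return the resulting tensor,
  # or an empty Tensor message if there is no result.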
  async def SendPrompt(self, request, context):
    shard = Shard(
      model_id=request.shard.model_id,
      start_layer=request.shard.start_layer,
      end_layer=request.shard.end_layer,
      n_layers=request.shard.n_layers,
    )
    prompt = request.prompt
    request_id = request.request_id
    inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
    result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
    tensor_data = result.tobytes() if result is not None else None
    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
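
  # Process an incoming tensor (an intermediate activation from a peer) on
  # this node's shard and return the output tensor.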
  async def SendTensor(self, request, context):
    shard = Shard(
      model_id=request.shard.model_id,
      start_layer=request.shard.start_layer,
      end_layer=request.shard.end_layer,
      n_layers=request.shard.n_layers,
    )
    tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
    request_id = request.request_id
    inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
    result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
    tensor_data = result.tobytes() if result is not None else None
    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
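
  # Process a training example; when training on a shard that is not the
  # first layer, gradients are returned alongside the loss.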
  async def SendExample(self, request, context):
    shard = Shard(
      model_id=request.shard.model_id,
      start_layer=request.shard.start_layer,
      end_layer=request.shard.end_layer,
      n_layers=request.shard.n_layers,
    )
    example = np.frombuffer(request.example.tensor_data, dtype=np.dtype(request.example.dtype)).reshape(request.example.shape)
    target = np.frombuffer(request.target.tensor_data, dtype=np.dtype(request.target.dtype)).reshape(request.target.shape)
    length = np.frombuffer(request.length.tensor_data, dtype=np.dtype(request.length.dtype)).reshape(request.length.shape)
    train = request.train
    request_id = request.request_id
    if train and not shard.is_first_layer():
      loss, grad = await self.node.process_example(shard, example, target, length, train, request_id)
      tensor_data = grad.tobytes()
      grad_tensor = node_service_pb2.Tensor(tensor_data=tensor_data, shape=grad.shape, dtype=str(grad.dtype))
      return node_service_pb2.Loss(loss=loss, grads=grad_tensor)
    else:
      loss = await self.node.process_example(shard, example, target, length, train, request_id)
      return node_service_pb2.Loss(loss=loss, grads=None)
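
  # Report this node's current topology: device capabilities per node plus
  # the peer connection graph.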
  async def CollectTopology(self, request, context):
    max_depth = request.max_depth
    visited = set(request.visited)
    topology = self.node.current_topology
    nodes = {
      node_id:
        node_service_pb2.DeviceCapabilities(
          model=cap.model,
          chip=cap.chip,
          memory=cap.memory,
          flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
        )
      for node_id, cap in topology.nodes.items()
    }
    peer_graph = {
      node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections])
      for node_id, connections in topology.peer_graph.items()
    }
    if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
    return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
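
  # Deliver a (possibly final) result to this node's on_token callbacks; the
  # payload is the result list, or a decoded tensor when one is attached.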
  async def SendResult(self, request, context):
    request_id = request.request_id
    result = request.result
    is_finished = request.is_finished
    img = request.tensor
    if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
    result = list(result)
    if len(img.tensor_data) > 0:
      result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
    self.node.on_token.trigger_all(request_id, result, is_finished)
    return node_service_pb2.Empty()
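
  # Pass an opaque status payload through to this node's on_opaque_status callbacks.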
  async def SendOpaqueStatus(self, request, context):
    request_id = request.request_id
    status = request.status
    if DEBUG >= 8: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
    self.node.on_opaque_status.trigger_all(request_id, status)
    return node_service_pb2.Empty()
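
  # Simple liveness probe; always reports healthy while the server is running.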
  async def HealthCheck(self, request, context):
    return node_service_pb2.HealthCheckResponse(is_healthy=True)
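
  # Rebuild the inference-state dict from its protobuf form: named tensors,
  # lists of tensors, and any extra JSON-encoded values.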
  def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
    inference_state = {}
    for k, tensor_data in inference_state_proto.tensor_data.items():
      np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
      inference_state[k] = mx.array(np_array)
    for k, tensor_list in inference_state_proto.tensor_list_data.items():
      inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]
    if inference_state_proto.other_data_json:
      other_data = json.loads(inference_state_proto.other_data_json)
      inference_state.update(other_data)
    return inference_state