grpc_server.py

import grpc
from concurrent import futures
import numpy as np
from . import node_service_pb2
from . import node_service_pb2_grpc
from inference.shard import Shard
from orchestration import Node
import uuid


class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
    """gRPC servicer that exposes a local Node to its peers over the NodeService API."""

    def __init__(self, node: Node, host: str, port: int):
        self.node = node
        self.host = host
        self.port = port
        self.server = None

    async def start(self) -> None:
        # Create an asyncio gRPC server with a thread pool for non-coroutine
        # handlers and a 32 MiB metadata size limit.
        self.server = grpc.aio.server(
            futures.ThreadPoolExecutor(max_workers=10),
            options=[
                ('grpc.max_metadata_size', 32 * 1024 * 1024),
            ],
        )
        node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
        listen_addr = f'{self.host}:{self.port}'
        self.server.add_insecure_port(listen_addr)
        await self.server.start()
        print(f"Server started, listening on {listen_addr}")

    async def stop(self) -> None:
        if self.server:
            # Give in-flight RPCs up to 5 seconds to complete before shutting down.
            await self.server.stop(grace=5)
            await self.server.wait_for_termination()
            print("Server stopped and all connections are closed")

    async def SendPrompt(self, request, context):
        # Rebuild the Shard described by the request and hand the prompt to the local node.
        shard = Shard(
            model_id=request.shard.model_id,
            start_layer=request.shard.start_layer,
            end_layer=request.shard.end_layer,
            n_layers=request.shard.n_layers,
        )
        prompt = request.prompt
        request_id = request.request_id
        result = await self.node.process_prompt(shard, prompt, request_id)
        if result is None:
            return node_service_pb2.Tensor()
        # Serialize the resulting numpy array as raw bytes plus shape/dtype metadata.
        return node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))

    async def SendTensor(self, request, context):
        shard = Shard(
            model_id=request.shard.model_id,
            start_layer=request.shard.start_layer,
            end_layer=request.shard.end_layer,
            n_layers=request.shard.n_layers,
        )
        # Reconstruct the numpy array from its raw bytes, dtype and shape.
        tensor = np.frombuffer(
            request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)
        ).reshape(request.tensor.shape)
        request_id = request.request_id
        result = await self.node.process_tensor(shard, tensor, request_id)
        print("SendTensor tensor result", result)
        if result is None:
            return node_service_pb2.Tensor()
        return node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))

    async def GetInferenceResult(self, request, context):
        request_id = request.request_id
        result = await self.node.get_inference_result(request_id)
        # result is a (tensor, is_finished) pair; the tensor may be None if nothing is ready yet.
        if result[0] is None:
            return node_service_pb2.InferenceResult(is_finished=result[1])
        return node_service_pb2.InferenceResult(
            tensor=node_service_pb2.Tensor(tensor_data=result[0].tobytes(), shape=result[0].shape, dtype=str(result[0].dtype)),
            is_finished=result[1],
        )

    async def ResetShard(self, request, context):
        # Clear any state the node holds for the given shard.
        shard = Shard(
            model_id=request.shard.model_id,
            start_layer=request.shard.start_layer,
            end_layer=request.shard.end_layer,
            n_layers=request.shard.n_layers,
        )
        print(f"Received ResetShard request: {shard}")
        await self.node.reset_shard(shard)
        return node_service_pb2.Empty()

    async def CollectTopology(self, request, context):
        max_depth = request.max_depth
        topology = await self.node.collect_topology(max_depth)
        # Convert the node's topology into protobuf maps of device capabilities and peer lists.
        nodes = {
            node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory)
            for node_id, cap in topology.nodes.items()
        }
        peer_graph = {
            node_id: node_service_pb2.Peers(peer_ids=peers)
            for node_id, peers in topology.peer_graph.items()
        }
        return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
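

# Usage sketch: one way this server might be driven from an asyncio entry point.
# The Node construction is elided because orchestration.Node's constructor is not
# shown in this module, and the host/port values below are placeholder assumptions.
if __name__ == "__main__":
    import asyncio

    async def _serve(node: Node, host: str = "0.0.0.0", port: int = 8080) -> None:
        server = GRPCServer(node, host, port)
        await server.start()
        try:
            # Serve until the task is cancelled (e.g. on Ctrl+C), then shut down gracefully.
            await asyncio.Event().wait()
        finally:
            await server.stop()

    # asyncio.run(_serve(node))  # call with a concrete Node instance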