grpc_server.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import grpc
  2. from concurrent import futures
  3. import numpy as np
  4. from . import node_service_pb2
  5. from . import node_service_pb2_grpc
  6. from inference.shard import Shard
  7. from orchestration import Node
  8. class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
  9. def __init__(self, node: Node, host: str, port: int):
  10. self.node = node
  11. self.host = host
  12. self.port = port
  13. self.server = None
  14. async def start(self) -> None:
  15. self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10))
  16. node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
  17. listen_addr = f'{self.host}:{self.port}'
  18. self.server.add_insecure_port(listen_addr)
  19. await self.server.start()
  20. print(f"Server started, listening on {listen_addr}")
  21. async def stop(self) -> None:
  22. if self.server:
  23. await self.server.stop(5) # 5 seconds grace period
  24. print("Server stopped")
  25. async def SendPrompt(self, request, context):
  26. shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
  27. prompt = request.prompt
  28. result = await self.node.process_prompt(shard, prompt)
  29. tensor_data = result.tobytes() if result is not None else None
  30. return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
  31. async def SendTensor(self, request, context):
  32. shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
  33. tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
  34. result = await self.node.process_tensor(shard, tensor)
  35. print("SendTensor tensor result", result)
  36. tensor_data = result.tobytes() if result is not None else None
  37. return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
  38. async def ResetShard(self, request, context):
  39. shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
  40. print(f"Received ResetShard request: {shard}")
  41. await self.node.reset_shard(shard)
  42. return node_service_pb2.Empty()