from concurrent import futures
import uuid

import grpc
import numpy as np

from . import node_service_pb2
from . import node_service_pb2_grpc
from inference.shard import Shard
from orchestration import Node


class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
    """gRPC service that exposes a local orchestration Node to peer nodes: prompts,
    tensors, shard resets and topology collection are forwarded to the wrapped Node."""

    def __init__(self, node: Node, host: str, port: int):
        self.node = node
        self.host = host
        self.port = port
        self.server = None

    async def start(self) -> None:
        self.server = grpc.aio.server(
            futures.ThreadPoolExecutor(max_workers=10),
            # Raise the metadata size limit to 32 MiB.
            options=[("grpc.max_metadata_size", 32 * 1024 * 1024)],
        )
        node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
        listen_addr = f"{self.host}:{self.port}"
        self.server.add_insecure_port(listen_addr)
        await self.server.start()
        print(f"Server started, listening on {listen_addr}")

    async def stop(self) -> None:
        if self.server:
            # Give in-flight RPCs up to 5 seconds to finish before shutting down.
            await self.server.stop(grace=5)
            await self.server.wait_for_termination()
            print("Server stopped and all connections are closed")

    async def SendPrompt(self, request, context):
        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
        result = await self.node.process_prompt(shard, request.prompt, request.request_id)
        if result is None:
            return node_service_pb2.Tensor()
        return node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))

    async def SendTensor(self, request, context):
        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
        # Rebuild the numpy array from the raw bytes, dtype and shape in the request.
        tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
        result = await self.node.process_tensor(shard, tensor, request.request_id)
        print("SendTensor tensor result", result)
        if result is None:
            return node_service_pb2.Tensor()
        return node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))

    async def GetInferenceResult(self, request, context):
        result = await self.node.get_inference_result(request.request_id)
        if result[0] is None:
            return node_service_pb2.InferenceResult()
        tensor = node_service_pb2.Tensor(tensor_data=result[0].tobytes(), shape=result[0].shape, dtype=str(result[0].dtype))
        return node_service_pb2.InferenceResult(tensor=tensor)

    async def ResetShard(self, request, context):
        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
        print(f"Received ResetShard request: {shard}")
        await self.node.reset_shard(shard)
        return node_service_pb2.Empty()

    async def CollectTopology(self, request, context):
        topology = await self.node.collect_topology(request.max_depth)
        nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory) for node_id, cap in topology.nodes.items()}
        peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
        return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
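

# Usage sketch (an illustrative addition, not part of the module itself): the server
# is driven from an asyncio entrypoint via start()/stop(); how a concrete
# orchestration Node is constructed depends on the orchestration package, so it is
# left as a placeholder here.
#
#   import asyncio
#
#   async def main():
#       node = ...  # a concrete orchestration.Node for this device
#       server = GRPCServer(node, host="0.0.0.0", port=8080)
#       await server.start()
#       try:
#           await asyncio.Event().wait()  # serve until cancelled
#       finally:
#           await server.stop()
#
#   asyncio.run(main())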