@@ -4,12 +4,10 @@ import numpy as np

 from . import node_service_pb2
 from . import node_service_pb2_grpc
+from exo import DEBUG
 from exo.inference.shard import Shard
-
 from exo.orchestration import Node

-import uuid
-
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     def __init__(self, node: Node, host: str, port: int):
         self.node = node
@@ -25,19 +23,20 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         listen_addr = f'{self.host}:{self.port}'
         self.server.add_insecure_port(listen_addr)
         await self.server.start()
-        print(f"Server started, listening on {listen_addr}")
+        if DEBUG >= 1: print(f"Server started, listening on {listen_addr}")

     async def stop(self) -> None:
         if self.server:
             await self.server.stop(grace=5)
             await self.server.wait_for_termination()
-            print("Server stopped and all connections are closed")
+            if DEBUG >= 1: print("Server stopped and all connections are closed")

     async def SendPrompt(self, request, context):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
         prompt = request.prompt
         request_id = request.request_id
         result = await self.node.process_prompt(shard, prompt, request_id)
+        if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
         tensor_data = result.tobytes() if result is not None else None
         return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()

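Note: the gating above relies on an integer DEBUG verbosity imported from the exo package (0 keeps the server silent, 1 logs lifecycle events, 2 and above logs per-request detail). A minimal sketch of such a flag, assuming it is read from an environment variable; the real definition lives in exo and may differ:

import os

# Integer verbosity level; unset defaults to 0 (silent).
DEBUG = int(os.getenv("DEBUG", default="0"))

Under that assumption, starting a node with DEBUG=2 in the environment would surface both the lifecycle prints and the per-request prints added in this patch.
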
@@ -47,19 +46,20 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         request_id = request.request_id

         result = await self.node.process_tensor(shard, tensor, request_id)
-        print("SendTensor tensor result", result)
+        if DEBUG >= 2: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
         tensor_data = result.tobytes() if result is not None else None
         return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()

     async def GetInferenceResult(self, request, context):
         request_id = request.request_id
         result = await self.node.get_inference_result(request_id)
+        if DEBUG >= 2: print(f"GetInferenceResult {request_id=}: {result}")
         tensor_data = result[0].tobytes() if result[0] is not None else None
         return node_service_pb2.InferenceResult(tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)), is_finished=result[1]) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])

     async def ResetShard(self, request, context):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
-        print(f"Received ResetShard request: {shard}")
+        if DEBUG >= 2: print(f"Received ResetShard request: {shard}")
         await self.node.reset_shard(shard)
         return node_service_pb2.Empty()
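Note: every handler serializes a numpy result the same way: raw bytes via tobytes(), plus the shape and the stringified dtype, with an empty Tensor message standing in for "no result". For reference, a peer can invert that encoding; a minimal sketch, assuming only the tensor_data/shape/dtype fields used above:

import numpy as np

def tensor_to_numpy(t):
    # An empty message (no tensor_data) means there was no result.
    if not t.tensor_data:
        return None
    # frombuffer reinterprets the raw bytes as a flat array; reshape
    # restores the dimensions sent alongside the payload.
    return np.frombuffer(t.tensor_data, dtype=np.dtype(t.dtype)).reshape(tuple(t.shape))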