@@ -40,8 +40,8 @@ class GRPCPeerHandle(PeerHandle):
         self.channel = None
         self.stub = None
 
-    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
-        request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers), request_id=request_id)
+    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+        request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers), request_id=request_id, inference_state=inference_state)
         response = await self.stub.SendPrompt(request)
 
         if not response.tensor_data or not response.shape or not response.dtype:
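This first hunk threads the new optional inference_state string through send_prompt and into the PromptRequest message; for the keyword argument to be valid, PromptRequest in node_service.proto must also gain a matching string field (not shown in this hunk). A minimal caller-side sketch of the new parameter, assuming an already-connected GRPCPeerHandle bound to `peer` and a Shard bound to `shard`; the JSON payload and request id below are purely illustrative, not exo's actual state format:

    import json

    async def prompt_with_state(peer, shard, prompt: str):
        # inference_state travels as an opaque string, so JSON is one
        # convenient encoding; these keys are made up for illustration.
        state = json.dumps({"step": 0, "sampler_seed": 42})
        return await peer.send_prompt(shard, prompt,
                                      request_id="example-request",
                                      inference_state=state)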
@@ -49,7 +49,7 @@ class GRPCPeerHandle(PeerHandle):
 
         return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
-    async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
+    async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
         request = node_service_pb2.TensorRequest(
             shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
             tensor = node_service_pb2.Tensor(
@@ -57,7 +57,8 @@ class GRPCPeerHandle(PeerHandle):
                 shape=tensor.shape,
                 dtype=str(tensor.dtype)
             ),
-            request_id=request_id
+            request_id=request_id,
+            inference_state=inference_state
         )
         response = await self.stub.SendTensor(request)
 
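The response path is unchanged in both methods: each reply carries raw bytes plus a shape and dtype, which np.frombuffer(...).reshape(...) reassembles into an array, mirroring tensor.tobytes() on the sending side. A self-contained sketch of that round-trip, with plain Python values standing in for the node_service_pb2.Tensor fields:

    import numpy as np

    def encode(tensor: np.ndarray):
        # The same three fields the Tensor message carries in the diff above.
        return tensor.tobytes(), tensor.shape, str(tensor.dtype)

    def decode(tensor_data: bytes, shape, dtype: str) -> np.ndarray:
        # Mirrors the deserialization in send_prompt and send_tensor.
        return np.frombuffer(tensor_data, dtype=np.dtype(dtype)).reshape(shape)

    original = np.arange(6, dtype=np.float32).reshape(2, 3)
    assert np.array_equal(decode(*encode(original)), original)

One caveat worth noting: np.frombuffer returns a read-only view over the underlying bytes, so a caller that wants to mutate the returned array needs to copy it first.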