@@ -52,7 +52,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     prompt = request.prompt
     request_id = request.request_id
-    inference_state = self.deserialize_inference_state(request.inference_state)
+    inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
     result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
     if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
@@ -68,7 +68,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id
-    inference_state = self.deserialize_inference_state(request.inference_state)
+    inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
     result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
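
Both hunks apply the same None guard before deserializing, so the change is equivalent to routing each call through a small shared helper on GRPCServer. A minimal sketch of that pattern; the helper name maybe_deserialize_inference_state is hypothetical and not part of this diff:

    # Hypothetical helper, not in this diff: factors out the None guard
    # that SendPrompt and SendTensor now apply inline.
    def maybe_deserialize_inference_state(self, proto_state):
      # Skip deserialization entirely when the request carried no state.
      if proto_state is None:
        return None
      return self.deserialize_inference_state(proto_state)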