commit fcc699a55f
Alex Cheema, 5 months ago

+ 1 - 1
exo/inference/dummy_inference_engine.py

@@ -25,7 +25,7 @@ class DummyInferenceEngine(InferenceEngine):
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     return self.tokenizer.decode(tokens)
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> tuple[np.ndarray, Optional[dict]]:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
     return input_data + 1 if self.shard.is_last_layer() else input_data, None
 
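For context on this change: a mutable default like inference_state: dict = {} is evaluated once, at function definition time, so every call that omits the argument shares the same dict. A minimal standalone sketch of the pitfall and of the Optional[dict] = None pattern the commit adopts (function names here are hypothetical):

  def append_token(token, state={}):  # one dict shared by every call
    state.setdefault("tokens", []).append(token)
    return state

  print(append_token(1))  # {'tokens': [1]}
  print(append_token(2))  # {'tokens': [1, 2]}  <- state leaked across calls

  def append_token_fixed(token, state=None):  # pattern used in this commit
    if state is None:
      state = {}  # a fresh dict on each call
    state.setdefault("tokens", []).append(token)
    return state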

+ 2 - 2
exo/networking/grpc/grpc_peer_handle.py

@@ -82,7 +82,7 @@ class GRPCPeerHandle(PeerHandle):
         n_layers=shard.n_layers,
       ),
       request_id=request_id,
-      inference_state=self.serialize_inference_state(inference_state)
+      inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
     )
     response = await self.stub.SendPrompt(request)
 
@@ -101,7 +101,7 @@ class GRPCPeerHandle(PeerHandle):
       ),
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       request_id=request_id,
-      inference_state=self.serialize_inference_state(inference_state)
+      inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
     )
     response = await self.stub.SendTensor(request)
 

+ 2 - 2
exo/networking/grpc/grpc_server.py

@@ -52,7 +52,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     prompt = request.prompt
     request_id = request.request_id
-    inference_state = self.deserialize_inference_state(request.inference_state)
+    inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
     result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
     if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
@@ -68,7 +68,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id
 
-    inference_state = self.deserialize_inference_state(request.inference_state)
+    inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
 
     result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
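The server applies the matching guard before deserializing. One hedged caveat, since node_service.proto is not part of this diff: if inference_state is a proto3 message field, request.inference_state is never None in Python (an unset field yields an empty default instance), so presence is conventionally tested with HasField. A sketch of that variant, under that assumption:

  def extract_inference_state(self, request):
    # Assumes inference_state is a singular message field in node_service.proto;
    # proto3 message fields never compare `is None`, so use HasField for presence.
    if not request.HasField("inference_state"):
      return None
    return self.deserialize_inference_state(request.inference_state)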