
fix inference_state serialization. related: #40 #44 #45

Alex Cheema 1 year ago
parent commit 1475c735c9

+ 3 - 8
exo/inference/tinygrad/inference.py

@@ -137,15 +137,10 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
         self.shard = None

     async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-        def encode_role(role: str):
-            return [self.tokenizer.special_tokens["<|start_header_id|>"]] + self.tokenizer.encode(role) + [self.tokenizer.special_tokens["<|end_header_id|>"]] + self.tokenizer.encode("\n\n")
-        def encode_message(role: str, content: str):
-            return encode_role(role) + self.tokenizer.encode(content.strip()) + [self.tokenizer.special_tokens["<|eot_id|>"]]
-
         await self.ensure_shard(shard)
-        start_pos = json.loads(inference_state)["start_pos"] if inference_state else 0
+        start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0

-        toks = [self.tokenizer.bos_id] + encode_message("user", prompt) + encode_role("assistant")
+        toks = self.tokenizer.encode(prompt)
         start_pos = prefill(self.model, toks[:-1], start_pos=start_pos)
         last_tok = toks[-1]

@@ -157,8 +152,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):

     async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
         await self.ensure_shard(shard)
+        start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0

-        start_pos = json.loads(inference_state)["start_pos"] if inference_state else 0
         output_data: np.ndarray = np.array([self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
         if output_data.size == 1:
            start_pos += 1
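
A minimal sketch of the round-trip these hunks rely on, standalone and not part of the diff (the helper names are illustrative): the engine serializes its running position as a small JSON document, and the patched lookup tolerates a state string that lacks the key.

    import json
    from typing import Optional

    def serialize_inference_state(start_pos: int) -> str:
        # The engine passes its running position along as a JSON string.
        return json.dumps({"start_pos": start_pos})

    def read_start_pos(inference_state: Optional[str]) -> int:
        # Mirrors the patched lookup: a missing state string or a missing key both fall back to 0.
        return json.loads(inference_state).get("start_pos", 0) if inference_state else 0

    assert read_start_pos(serialize_inference_state(42)) == 42
    assert read_start_pos(None) == 0
    assert read_start_pos("{}") == 0   # the old indexing raised KeyError here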

+ 5 - 4
exo/networking/grpc/grpc_peer_handle.py

@@ -40,8 +40,8 @@ class GRPCPeerHandle(PeerHandle):
         self.channel = None
         self.stub = None

-    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
-        request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers), request_id=request_id)
+    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+        request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers), request_id=request_id, inference_state=inference_state)
         response = await self.stub.SendPrompt(request)

         if not response.tensor_data or not response.shape or not response.dtype:
@@ -49,7 +49,7 @@ class GRPCPeerHandle(PeerHandle):

         return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

-    async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
+    async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
         request = node_service_pb2.TensorRequest(
             shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
             tensor = node_service_pb2.Tensor(
@@ -57,7 +57,8 @@ class GRPCPeerHandle(PeerHandle):
                 shape=tensor.shape,
                 dtype=str(tensor.dtype)
             ),
-            request_id=request_id
+            request_id=request_id,
+            inference_state=inference_state
         )
         response = await self.stub.SendTensor(request)
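
With these signatures, a caller hands the serialized state to the peer handle as an explicit keyword and it is copied into the protobuf request. A hedged usage sketch, assuming an already-connected GRPCPeerHandle named peer and a Shard instance named shard from the repo:

    import json

    async def forward_prompt_with_state(peer, shard, prompt: str, request_id: str, start_pos: int):
        # Serialize the running position the same way the tinygrad engine does.
        state = json.dumps({"start_pos": start_pos})
        # inference_state is now an explicit keyword on the patched send_prompt.
        return await peer.send_prompt(shard, prompt, request_id=request_id, inference_state=state)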
 
 

+ 4 - 4
exo/networking/grpc/node_service.proto

@@ -23,15 +23,15 @@ message Shard {
 message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
-  optional string inference_state = 3;
-  optional string request_id = 4;
+  optional string request_id = 3;
+  optional string inference_state = 4;
 }

 message TensorRequest {
   Shard shard = 1;
   Tensor tensor = 2;
-  optional string inference_state = 3;
-  optional string request_id = 4;
+  optional string request_id = 3;
+  optional string inference_state = 4;
 }

 message GetInferenceResultRequest {
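
One consequence worth spelling out (not stated in the diff): protobuf identifies fields by their numbers on the wire, so swapping request_id and inference_state between tags 3 and 4 means every node must run the same proto version. A self-contained sketch of why, using the wire format for length-delimited (string) fields:

    # Hand-rolled encoding of an optional string field: a tag byte ((field_number << 3) | 2),
    # a one-byte length, then the UTF-8 bytes (lengths kept under 128 for simplicity).
    def encode_string_field(field_number: int, value: str) -> bytes:
        data = value.encode()
        assert len(data) < 128
        return bytes([(field_number << 3) | 2, len(data)]) + data

    # Patched schema: request_id = 3, inference_state = 4.
    wire = encode_string_field(3, "req-1") + encode_string_field(4, '{"start_pos": 7}')
    # A peer still built from the old .proto reads tag 3 as inference_state and tag 4 as
    # request_id, so the two strings silently swap instead of raising a parse error.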

File diff suppressed because it is too large
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 2 - 2
exo/networking/peer_handle.py

@@ -27,11 +27,11 @@ class PeerHandle(ABC):
         pass

     @abstractmethod
-    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
+    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
         pass

     @abstractmethod
-    async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None) -> Optional[np.array]:
+    async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
         pass

     @abstractmethod
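
Every PeerHandle implementation now has to accept the extra keyword, even if it ignores it. A minimal sketch of conforming signatures (only the two patched methods are shown; PeerHandle's other abstract methods would still need implementing, and the Shard import path is assumed from the repo layout):

    from typing import Optional
    import numpy as np
    from exo.inference.shard import Shard
    from exo.networking.peer_handle import PeerHandle

    class NoopPeerHandle(PeerHandle):
        # Illustrative no-op handle; a real implementation forwards over the network.
        async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
            return None

        async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
            return None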

+ 2 - 2
exo/orchestration/standard_node.py

@@ -166,9 +166,9 @@ class StandardNode(Node):
             if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")

             if isinstance(tensor_or_prompt, np.ndarray):
-                await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id)
+                await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
             else:
-                await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id)
+                await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)

     def get_current_shard(self, base_shard: Shard) -> Shard:
         partitions = self.partitioning_strategy.partition(self.topology)
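
Taken together, the state now survives a hop: the inference engine emits a JSON string, StandardNode forwards it unchanged alongside the tensor or prompt, and the next node's engine resumes from the same start_pos. A self-contained sketch of that flow, with a plain function standing in for one engine step:

    import json
    from typing import Optional, Tuple

    def engine_step(inference_state: Optional[str]) -> Tuple[int, str]:
        # Stand-in for one inference step: read the position, advance it, re-serialize.
        start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
        return start_pos, json.dumps({"start_pos": start_pos + 1})

    state = None
    for hop in range(3):  # three nodes handling the same request in sequence
        start_pos, state = engine_step(state)
        print(f"hop {hop}: resumed at start_pos={start_pos}")
    # hop 0: resumed at start_pos=0
    # hop 1: resumed at start_pos=1
    # hop 2: resumed at start_pos=2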

Some files were not shown because too many files changed in this diff