Alex Cheema, 1 year ago
commit 03ba31c020

+ 3 - 4
exo/inference/inference_engine.py

@@ -1,17 +1,16 @@
 import numpy as np
-import mlx.nn as nn
 
-from typing import Tuple
+from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from .shard import Shard
 
 class InferenceEngine(ABC):
     @abstractmethod
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> Tuple[np.ndarray, bool]:
+    async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
         pass
 
     @abstractmethod
-    async def infer_prompt(self, shard: Shard, prompt: str) -> Tuple[np.ndarray, bool]:
+    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
         pass
 
     @abstractmethod

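The updated contract threads an opaque inference_state string through every call and returns it alongside the output tensor and the finished flag. A minimal sketch of a conforming engine (not part of this commit), assuming the third abstract method cut off above is the reset_shard hook used by the concrete engines below:

import json
import numpy as np
from typing import Optional, Tuple

from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard

class EchoEngine(InferenceEngine):
    # Hypothetical engine for illustration only: it echoes its input and passes
    # the opaque inference_state string through unchanged.
    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
        output = np.array(list(prompt.encode()))
        return output, inference_state or json.dumps({}), False

    async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
        return input_data, inference_state or json.dumps({}), True

    async def reset_shard(self, shard: Shard):
        pass  # signature assumed from the concrete engines in this diff
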
+ 4 - 3
exo/inference/mlx/sharded_inference_engine.py

@@ -4,6 +4,7 @@ from ..inference_engine import InferenceEngine
 from .sharded_model import StatefulShardedModel
 from .sharded_utils import load_shard
 from ..shard import Shard
+from typing import Optional
 
 class MLXFixedShardInferenceEngine(InferenceEngine):
     def __init__(self, model_path: str, shard: Shard):
@@ -12,7 +13,7 @@ class MLXFixedShardInferenceEngine(InferenceEngine):
         model_shard, self.tokenizer = load_shard(model_path, shard)
         self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
 
-    async def infer_prompt(self, shard: Shard, prompt: str) -> (np.ndarray, bool):
+    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, bool):
         if shard != self.shard:
             raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
 
@@ -38,12 +39,12 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     def __init__(self):
         self.shard = None
 
-    async def infer_prompt(self, shard: Shard, prompt: str) -> (np.ndarray, bool):
+    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, bool):
         await self.ensure_shard(shard)
         output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
         return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, bool):
+    async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, bool):
         await self.ensure_shard(shard)
         output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
         return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id

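The MLX engines in this hunk gain the inference_state parameter but still return two values, while the base class above now expects a (tensor, state, finished) triple. A sketch (not from the commit) of how MLXDynamicShardInferenceEngine.infer_prompt could meet the three-value contract, assuming the MLX path keeps its decoding state inside StatefulShardedModel and therefore has nothing to serialize:

async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    await self.ensure_shard(shard)
    # Same computation as in the diff; an empty string stands in for state that
    # the stateful MLX model already holds internally.
    output_data = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
    is_finished = output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
    return output_data, "", is_finished
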
+ 5 - 3
exo/inference/test_inference_engine.py

@@ -7,15 +7,17 @@ import numpy as np
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine: InferenceEngine, model_id: str, input_data: np.array):
     # inference_engine.reset_shard(Shard("", 0,0,0))
-    resp_full, _ = await inference_engine.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=1, n_layers=2), prompt="In one word, what is the capital of USA? ")
+    prompt = "In a single word only, what is the capital of Japan? "
+    resp_full, _, _ = await inference_engine.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=1, n_layers=2), prompt=prompt)
 
     print("resp_full", resp_full)
     print("decoded", inference_engine.tokenizer.decode(resp_full))
 
     # inference_engine.reset_shard(Shard("", 0,0,0))
 
-    # resp1, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=0, n_layers=2), input_data=input_data)
-    # resp2, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=1, end_layer=1, n_layers=2), input_data=resp1)
+    resp1, inference_state, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=0, n_layers=2), input_data=input_data)
+    print(f"Intermediate {inference_state=}")
+    resp2, _, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=1, end_layer=1, n_layers=2), input_data=resp1, inference_state=inference_state)
 
     # assert np.array_equal(resp_full, resp2)
 

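A hypothetical driver for the test above; the engine class comes from this commit, but the model id and input token are placeholders:

import asyncio
import numpy as np
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine

# Placeholder model id and token id; substitute whatever the shard loader
# actually resolves in this repository.
asyncio.run(test_inference_engine(
    TinygradDynamicShardInferenceEngine(),
    model_id="llama3-8b",
    input_data=np.array([128000]),
))
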
+ 15 - 7
exo/inference/tinygrad/inference.py

@@ -1,6 +1,6 @@
 
 from pathlib import Path
-from typing import List
+from typing import List, Optional
 import json, argparse, random, time
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
@@ -147,7 +147,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     def __init__(self):
         self.shard = None
 
-    async def infer_prompt(self, shard: Shard, prompt: str) -> (np.ndarray, bool):
+    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
         def encode_role(role: str):
             return [self.tokenizer.special_tokens["<|start_header_id|>"]] + self.tokenizer.encode(role) + [self.tokenizer.special_tokens["<|end_header_id|>"]] + self.tokenizer.encode("\n\n")
         def encode_message(role: str, content: str):
@@ -161,14 +161,22 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
         last_tok = toks[-1]
 
         output_data = np.array(self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist())
-        start_pos += 1
+        print(f"{output_data.size=}")
+        if output_data.size == 1:
+           start_pos += 1
 
-        return output_data, output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens
+        return output_data, json.dumps({"start_pos": start_pos}), output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens
 
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, bool):
+    async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
         await self.ensure_shard(shard)
-        output_data: np.ndarray = np.array(self.model(Tensor([input_data]), 0, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist())
-        return output_data, output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens
+
+        start_pos = json.loads(inference_state)["start_pos"] if inference_state else 0
+        output_data: np.ndarray = np.array(self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist())
+        print(f"{output_data.size=}")
+        if output_data.size == 1:
+           start_pos += 1
+
+        return output_data, json.dumps({"start_pos": start_pos}), output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens
 
     async def reset_shard(self, shard: Shard):
         await self.ensure_shard(shard)

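The tinygrad engine serializes its decoding position as a small JSON document in inference_state, so the next shard can resume at the right KV-cache offset. The round trip in isolation:

import json

# What the first shard returns as its state, and how the next shard recovers
# start_pos from it, exactly as infer_tensor does above.
state = json.dumps({"start_pos": 42})
start_pos = json.loads(state)["start_pos"] if state else 0
assert start_pos == 42
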
+ 2 - 1
exo/networking/grpc/grpc_server.py

@@ -44,8 +44,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
         tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
         request_id = request.request_id
+        inference_state = request.inference_state
 
-        result = await self.node.process_tensor(shard, tensor, request_id)
+        result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
         if DEBUG >= 2: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
         tensor_data = result.tobytes() if result is not None else None
         return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()

+ 4 - 2
exo/networking/grpc/node_service.proto

@@ -21,13 +21,15 @@ message Shard {
 message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
-  optional string request_id = 3;
+  optional string inference_state = 3;
+  optional string request_id = 4;
 }
 
 message TensorRequest {
   Shard shard = 1;
   Tensor tensor = 2;
-  optional string request_id = 3;
+  optional string inference_state = 3;
+  optional string request_id = 4;
 }
 
 message GetInferenceResultRequest {

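Since inference_state takes field number 3 and request_id moves to field 4 in both messages, the generated Python stubs have to be regenerated (node_service_pb2.py appears in the file list below). A sketch of building a TensorRequest with the new field, assuming regenerated stubs; the model id, shape, and values are placeholders:

import json
import numpy as np
from exo.networking.grpc import node_service_pb2

activation = np.zeros((1, 4096), dtype=np.float32)  # placeholder intermediate tensor
request = node_service_pb2.TensorRequest(
    shard=node_service_pb2.Shard(model_id="llama3-8b", start_layer=1, end_layer=1, n_layers=2),
    tensor=node_service_pb2.Tensor(tensor_data=activation.tobytes(), shape=activation.shape, dtype=str(activation.dtype)),
    inference_state=json.dumps({"start_pos": 42}),
    request_id="example-request",
)
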
File diff suppressed because it is too large
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 2 - 2
exo/orchestration/node.py

@@ -14,11 +14,11 @@ class Node(ABC):
         pass
 
     @abstractmethod
-    async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+    async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
         pass
 
     @abstractmethod
-    async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+    async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
         pass
 
     @abstractmethod

+ 9 - 9
exo/orchestration/standard_node.py

@@ -37,7 +37,7 @@ class StandardNode(Node):
         await self.discovery.stop()
         await self.server.stop()
 
-    async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+    async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
         if request_id is None:
             request_id = str(uuid.uuid4())
         if request_id not in self.buffered_token_output:
@@ -49,7 +49,7 @@ class StandardNode(Node):
             await self.forward_to_next_shard(shard, prompt, request_id)
             return
 
-        result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
+        result, inference_state, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt, inference_state=inference_state)
         is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
         if is_finished:
             self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -61,11 +61,11 @@ class StandardNode(Node):
         if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
         if not is_finished:
-            asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
+            asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
 
         return np.array(self.buffered_token_output[request_id]) if len(self.buffered_token_output[request_id]) > 0 else None
 
-    async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+    async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
         if request_id is None:
             request_id = str(uuid.uuid4())
         if request_id not in self.buffered_token_output:
@@ -73,7 +73,7 @@ class StandardNode(Node):
 
         try:
             if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-            result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
+            result, inference_state, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor, inference_state=inference_state)
             is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
             if is_finished:
                 self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -84,7 +84,7 @@ class StandardNode(Node):
             if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
             if not is_finished:
-                asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
+                asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
 
             return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
         except Exception as e:
@@ -93,7 +93,7 @@ class StandardNode(Node):
             traceback.print_exc()
             return None
 
-    async def forward_to_next_shard(self, shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str) -> None:
+    async def forward_to_next_shard(self, shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str, inference_state: Optional[str] = None) -> None:
         if not self.partitioning_strategy:
             if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
             return
@@ -109,9 +109,9 @@ class StandardNode(Node):
             if next_partition:
                 if next_partition.node_id == self.id:
                     if isinstance(tensor_or_prompt, np.ndarray):
-                        await self.process_tensor(shard, tensor_or_prompt, request_id)
+                        await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
                     else:
-                        await self.process_prompt(shard, tensor_or_prompt, request_id)
+                        await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state)
                     return
 
                 target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)

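Taken together, process_prompt and process_tensor now hand the engine's state string to forward_to_next_shard, which passes it on to whichever node owns the next partition. A condensed, illustrative version of that hand-off over two local shards (the engine instance and model id are placeholders; the real path goes through StandardNode and gRPC):

from exo.inference.shard import Shard

async def run_two_shards(engine, input_data):
    # Mirrors the threading above: each hop returns (output, state, finished)
    # and the opaque state is forwarded unchanged to the next shard.
    inference_state = None
    for shard in (Shard(model_id="llama3-8b", start_layer=0, end_layer=0, n_layers=2),
                  Shard(model_id="llama3-8b", start_layer=1, end_layer=1, n_layers=2)):
        input_data, inference_state, is_finished = await engine.infer_tensor(
            shard, input_data, inference_state=inference_state)
        if is_finished:
            break
    return input_data, inference_state
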
Some files were not shown because too many files changed in this diff