
Fixing tinygrad model

Nel Nibcord 7 months ago
commit 37a75d6b96

+ 4 - 7
exo/inference/mlx/sharded_inference_engine.py

@@ -44,7 +44,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.caches = OrderedDict()
     self.session = {}
 
-  async def poll_cache(self, request_id: str, max_caches=2):
+  async def poll_state(self, request_id: str, max_caches=2):
     if request_id in self.caches:
       self.caches.move_to_end(request_id)
     else:
@@ -52,7 +52,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
       if len(self.caches) > max_caches:
         self.caches.popitem(last=False)
       self.caches[request_id] = newcache
-    return self.caches[request_id]
+    return {"cache": self.caches[request_id]}
 
   async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
     y = mx.array(x)
@@ -80,13 +80,10 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    #print(f"infer_tensor in <- {input_data}")
     loop = asyncio.get_running_loop()
-    cache = await self.poll_cache(request_id)
+    state = await self.poll_state(request_id)
     x = mx.array(input_data).astype(mx.int64) if self.shard.is_first_layer() else mx.array(input_data)
-    #print(f"Infer Tensor: {x=}")
-    output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, cache=cache)))
-    #print(f"infer_tensor out -> {output_data}")
+    output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, **state)))
     return output_data
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
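
The MLX change swaps a bare cache return for a kwargs dict, so the forward call can stay self.model(x, **state) no matter what the state carries. A minimal standalone sketch of that pattern, with toy names (ToyEngine, a plain list standing in for the real KV cache) rather than exo code:

```python
from collections import OrderedDict

class ToyEngine:
  """Illustrative only: per-request caches with LRU-style eviction."""
  def __init__(self, max_caches=2):
    self.caches = OrderedDict()
    self.max_caches = max_caches

  def poll_state(self, request_id):
    if request_id in self.caches:
      self.caches.move_to_end(request_id)        # mark as most recently used
    else:
      if len(self.caches) >= self.max_caches:
        self.caches.popitem(last=False)          # evict the least recently used request
      self.caches[request_id] = []               # stand-in for a real KV cache
    return {"cache": self.caches[request_id]}    # kwargs dict, as in the diff

  def model(self, x, cache=None):
    cache.append(len(x))                         # stand-in for a forward pass
    return x

engine = ToyEngine()
state = engine.poll_state("req-1")
engine.model([1, 2, 3], **state)                 # same call shape as self.model(x, **state)
```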

+ 14 - 2
exo/inference/tinygrad/inference.py

@@ -11,7 +11,7 @@ import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
-from .stateful_model import StatefulModel
+from .stateful_model import StatefulModel, make_prompt_state
 from .losses import length_masked_ce_loss
 import asyncio
 
@@ -63,8 +63,17 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard = None
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
+    self.states = OrderedDict()
     self.session = {}
 
+  def poll_state(self, x: Tensor, request_id: str, max_states=2):
+    if request_id not in self.states:
+      if len(self.states) >= max_states:
+        self.states.popitem(last=False)
+      self.states[request_id] = make_prompt_state(self.model, self.shard, x)
+    state = self.states[request_id]
+    return {"start_pos": state.start, "cache": state.cache}
+
   async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
     logits = x[:, -1, :]
     def sample_wrapper():
@@ -83,7 +92,10 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize())
+    x = Tensor(input_data)
+    state = self.poll_state(x, request_id)
+    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(x, **state).realize())
+    self.states[request_id].start += x.shape[1]
     return output_data.numpy()
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
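
The tinygrad engine follows the same idea, but its state also carries start_pos, which has to advance by the number of tokens just processed so the next call resumes where the previous one stopped. A self-contained sketch of that bookkeeping; ModelState mirrors the shape used in the diff, while the engine and model here are stubs:

```python
from collections import OrderedDict

class ModelState:
  def __init__(self, cache, start: int = 0):
    self.cache = cache
    self.start = start

class ToyTinygradEngine:
  """Illustrative only: per-request ModelState with LRU-style eviction."""
  def __init__(self):
    self.states = OrderedDict()

  def poll_state(self, request_id, max_states=2):
    if request_id not in self.states:
      if len(self.states) >= max_states:
        self.states.popitem(last=False)           # drop the oldest request's state
      self.states[request_id] = ModelState(cache=[])
    state = self.states[request_id]
    return {"start_pos": state.start, "cache": state.cache}

  def infer(self, request_id, tokens):
    state = self.poll_state(request_id)
    # a real engine would run self.model(Tensor(tokens), **state) here
    self.states[request_id].start += len(tokens)  # resume point for the next call
    return state

engine = ToyTinygradEngine()
engine.infer("req-1", [10, 11, 12])
print(engine.poll_state("req-1")["start_pos"])    # -> 3
```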

+ 1 - 1
exo/inference/tinygrad/models/llama.py

@@ -230,7 +230,7 @@ class Transformer:
       return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
     return self.forward_base(x, start_pos, cache=cache)
 
-  def __call__(self, tokens: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
+  def __call__(self, x: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
     # TODO: better way to handle the first call v.s. the rest?
     h = self.embed(x)
     return self.forward(h, start_pos, cache=cache)

+ 8 - 2
exo/inference/tinygrad/stateful_model.py

@@ -16,6 +16,13 @@ class ModelState:
     self.cache = cache
     self.start = start
 
+def make_prompt_state(model, shard, x):
+  cache = [create_kv_cache(x, model.layers[i].attention.max_context, model.layers[i].attention.n_kv_heads, model.layers[i].attention.head_dim) for i in range(shard.start_layer, shard.end_layer + 1)]
+
+  return ModelState(cache)
+
+
+
 class StatefulModel:
   def __init__(self, model, max_states: int = 2):
     super().__init__()
@@ -24,11 +31,10 @@ class StatefulModel:
     self.states = OrderedDict()
  
   def init_cache(self, x: Tensor, request_id: str):
-    cache = [create_kv_cache(x, self.model.layers[i].attention.max_context, self.model.layers[i].attention.n_kv_heads, self.model.layers[i].attention.head_dim) for i in range(self.model.shard.start_layer, self.model.shard.end_layer + 1)]
     if len(self.states) >= self.max_states:
       self.states.popitem(last=False)
 
-    self.states[request_id] = ModelState(cache)
+    self.states[request_id] = make_prompt_state(self.model, self.model.shard, x)
 
   def __call__(self, x: Tensor, request_id: Optional[str] = None, use_cache: bool = True): 
     h = self.model.embed(x)
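
With make_prompt_state factored out, StatefulModel.init_cache becomes a thin wrapper: the helper builds one KV-cache entry for each layer the shard owns, sized from that layer's attention settings. A rough sketch of that shape; Attention, Layer, Shard and create_kv_cache below are stand-ins for the tinygrad/exo types, not the real implementations:

```python
from dataclasses import dataclass, field

@dataclass
class Attention:            # stand-in: only the fields make_prompt_state reads
  max_context: int = 1024
  n_kv_heads: int = 8
  head_dim: int = 64

@dataclass
class Layer:
  attention: Attention = field(default_factory=Attention)

@dataclass
class Shard:                # stand-in for exo's Shard (inclusive layer range)
  start_layer: int
  end_layer: int

def create_kv_cache(max_context, n_kv_heads, head_dim):
  # placeholder: the real helper allocates key/value tensors for one layer
  return {"shape": (2, max_context, n_kv_heads, head_dim)}

def make_prompt_state(layers, shard):
  # one cache per layer in [start_layer, end_layer], as in the diff
  return [create_kv_cache(l.attention.max_context, l.attention.n_kv_heads, l.attention.head_dim)
          for l in layers[shard.start_layer:shard.end_layer + 1]]

layers = [Layer() for _ in range(8)]
print(len(make_prompt_state(layers, Shard(start_layer=2, end_layer=4))))  # -> 3
```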