
Fixing tinygrad model

Nel Nibcord committed 8 months ago
commit 67f5ae25a5
2 changed files with 13 additions and 12 deletions
  1. exo/inference/tinygrad/inference.py (+12 -11)
  2. exo/inference/tinygrad/stateful_model.py (+1 -1)

+ 12 - 11
exo/inference/tinygrad/inference.py

@@ -13,6 +13,7 @@ from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
 from .stateful_model import StatefulModel, make_prompt_state
 from .losses import length_masked_ce_loss
+from collections import OrderedDict
 import asyncio
 
 Tensor.no_grad = False
@@ -66,14 +67,14 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.states = OrderedDict()
     self.session = {}
 
-  async def poll_state(self, request_id: str, max_caches=2):
+  def poll_state(self, x, request_id: str, max_states=2):
     if request_id not in self.states:
-      if len(self.states) >= self.max_states:
+      if len(self.states) >= max_states:
         self.states.popitem(last=False)
-      make_prompt_state(
+      self.states[request_id] = make_prompt_state(x, self.model, self.shard)
     else:
       self.states.move_to_end(request_id)
-    state = self.state[request_id]
+    state = self.states[request_id]
     return {"start_pos": state.start, "cache": state.cache}
 
   async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
@@ -94,13 +95,13 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    state = self.poll_state(request_id)
-    def wrap_infer(data):
-      x = Tensor(data)
-      state = poll_state(request_id, x)
-      self.model(x, **state)
-      self.state[request_id].start += x.shape[1]
-    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: wrap_infer(input_data).realize())
+    def wrap_infer():
+      x = Tensor(input_data)
+      state = self.poll_state(x, request_id)
+      out = self.model(x, **state)
+      self.states[request_id].start += x.shape[1]
+      return out.realize()
+    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
     return output_data.numpy()
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
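
For context, the rewritten poll_state uses a collections.OrderedDict as a small LRU map: on a miss it may evict the oldest entry via popitem(last=False), and on a hit it refreshes the entry with move_to_end. A minimal self-contained sketch of that pattern, with dict() standing in for the real prompt state (the names here are illustrative, not part of the diff):

from collections import OrderedDict

states = OrderedDict()

def poll_state(request_id, max_states=2):
    if request_id not in states:
        if len(states) >= max_states:
            states.popitem(last=False)  # evict least recently used
        states[request_id] = dict()     # stand-in for make_prompt_state(...)
    else:
        states.move_to_end(request_id)  # mark as most recently used
    return states[request_id]

poll_state("a"); poll_state("b"); poll_state("a")
poll_state("c")  # evicts "b", the oldest entry
assert "b" not in states and "a" in states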
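The infer_tensor change also moves the whole tinygrad call chain, including .realize(), inside the closure handed to run_in_executor, so the blocking work runs on the ThreadPoolExecutor instead of the event loop (the old code returned nothing from wrap_infer, then called .realize() on that None). A sketch of the offloading pattern under the same ThreadPoolExecutor setup, with compute standing in for the Tensor/model work:

import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)

def compute(data):
    # Stand-in for Tensor(input_data) -> model(...) -> out.realize():
    # all blocking work stays inside the function run by the executor.
    return [v * 2 for v in data]

async def infer(data):
    return await asyncio.get_running_loop().run_in_executor(
        executor, lambda: compute(data)
    )

print(asyncio.run(infer([1, 2, 3])))  # [2, 4, 6]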

+ 1 - 1
exo/inference/tinygrad/stateful_model.py

@@ -16,7 +16,7 @@ class ModelState:
     self.cache = cache
     self.start = start
 
-def make_prompt_state(model, shard, x):
+def make_prompt_state(x, model, shard):
   cache = [create_kv_cache(x, model.layers[i].attention.max_context, model.layers[i].attention.n_kv_heads, model.layers[i].attention.head_dim) for i in range(shard.start_layer, shard.end_layer + 1)]
 
   return ModelState(cache)
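
The parameter reorder keeps make_prompt_state in line with its new positional call site, self.states[request_id] = make_prompt_state(x, self.model, self.shard); under the old (model, shard, x) order that call would have bound the input tensor to model. A runnable sketch of the fixed call path, with stubs in place of the real Shard and tinygrad's create_kv_cache:

class ModelState:
    def __init__(self, cache, start=0):
        self.cache = cache
        self.start = start

def create_kv_cache(x, layer_idx):
    return {"layer": layer_idx, "seq": len(x)}  # stub cache

def make_prompt_state(x, model, shard):
    # One KV cache per layer this shard owns (end_layer inclusive).
    cache = [create_kv_cache(x, i)
             for i in range(shard.start_layer, shard.end_layer + 1)]
    return ModelState(cache)

class Shard:
    def __init__(self, start_layer, end_layer):
        self.start_layer, self.end_layer = start_layer, end_layer

state = make_prompt_state([1, 2, 3], model=None, shard=Shard(0, 1))
assert len(state.cache) == 2 and state.start == 0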