瀏覽代碼

Fixing tinygrad model

Nel Nibcord 7 月之前
父節點
當前提交
bfa3b36be5
共有 1 個文件被更改,包括 12 次插入和 4 次删除
  1. 12 4
      exo/inference/tinygrad/inference.py

+ 12 - 4
exo/inference/tinygrad/inference.py

@@ -67,8 +67,12 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.session = {}
 
   async def poll_state(self, request_id: str, max_caches=2):
-    if len(self.states) >= self.max_states:
-      self.states.popitem(last=False)
+    if request_id not in self.states:
+      if len(self.states) >= self.max_states:
+        self.states.popitem(last=False)
+      make_prompt_state(
+    else:
+      self.states.move_to_end(request_id)
     state = self.state[request_id]
     return {"start_pos": state.start, "cache": state.cache}
 
@@ -91,8 +95,12 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
     state = self.poll_state(request_id)
-    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), **state).realize())
-    self.state[request_id].start += input_data.shape[1]
+    def wrap_infer(data):
+      x = Tensor(data)
+      state = poll_state(request_id, x)
+      self.model(x, **state)
+      self.state[request_id].start += x.shape[1]
+    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: wrap_infer(input_data).realize())
     return output_data.numpy()
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):