Răsfoiți Sursa

update tinygrad

Alex Cheema 7 luni în urmă
părinte
commit
9f6c688d62
2 fișiere modificate, cu 4 adăugiri și 4 ștergeri
  1. 3 3
      exo/inference/tinygrad/inference.py
  2. 1 1
      setup.py

+ 3 - 3
exo/inference/tinygrad/inference.py

@@ -79,8 +79,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     return {"start_pos": state.start, "cache": state.cache}
 
   async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
-    logits = x[:, -1, :]
     def sample_wrapper():
+      logits = x[:, -1, :]
       return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
     return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
 
@@ -112,9 +112,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
       state = self.poll_state(h, request_id)
       out = self.model.forward(h, **state)
       self.states[request_id].start += x.shape[1]
-      return out.realize()
+      return out.numpy()
     output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
-    return output_data.numpy(), inference_state
+    return output_data, inference_state
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
     def step(x, y, l):

+ 1 - 1
setup.py

@@ -29,7 +29,7 @@ install_requires = [
   "transformers==4.46.3",
   "uuid==1.30",
   "uvloop==0.21.0",
-  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
+  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@ec120ce6b9ce8e4ff4b5692566a683ef240e8bc8",
 ]
 
 extras_require = {