Nel Nibcord vor 4 Monaten
Ursprung
Commit
8f0d19e9b0
1 geänderte Dateien mit 3 neuen und 2 gelöschten Zeilen
  1. 3 2
      exo/inference/tinygrad/inference.py

+ 3 - 2
exo/inference/tinygrad/inference.py

@@ -102,8 +102,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     await self.ensure_shard(shard)
     def wrap_infer():
       x = Tensor(input_data)
-      state = self.poll_state(x, request_id)
-      out = self.model(x, **state)
+      h = self.model.embed(x)
+      state = self.poll_state(h, request_id)
+      out = self.model.forward(h, **state)
       self.states[request_id].start += x.shape[1]
       return out.realize()
     output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)