|
@@ -102,8 +102,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
await self.ensure_shard(shard)
|
|
|
def wrap_infer():
|
|
|
x = Tensor(input_data)
|
|
|
- state = self.poll_state(x, request_id)
|
|
|
- out = self.model(x, **state)
|
|
|
+ h = self.model.embed(x)
|
|
|
+ state = self.poll_state(h, request_id)
|
|
|
+ out = self.model.forward(h, **state)
|
|
|
self.states[request_id].start += x.shape[1]
|
|
|
return out.realize()
|
|
|
output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
|