|
@@ -111,8 +111,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
Tensor.training = False
|
|
Tensor.training = False
|
|
return self.session['loss'](self.model, x, y, l)
|
|
return self.session['loss'](self.model, x, y, l)
|
|
await self.ensure_shard(shard)
|
|
await self.ensure_shard(shard)
|
|
- await self.ensure_session('loss', lambda: loss)
|
|
|
|
- await self.ensure_session('jit', lambda: TinyJit(step))
|
|
|
|
score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths))
|
|
score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths))
|
|
out = score.numpy()
|
|
out = score.numpy()
|
|
return out
|
|
return out
|
|
@@ -126,9 +124,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
self.session['opt'].step()
|
|
self.session['opt'].step()
|
|
return score
|
|
return score
|
|
await self.ensure_shard(shard)
|
|
await self.ensure_shard(shard)
|
|
- await self.ensure_session('loss', lambda: loss)
|
|
|
|
- await self.ensure_session('opt', lambda: opt(nn.state.get_parameters(self.model.model), lr=lr))
|
|
|
|
- await self.ensure_session('jit', lambda: TinyJit(step))
|
|
|
|
|
|
|
|
score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths).realize())
|
|
score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths).realize())
|
|
|
|
|