@@ -79,8 +79,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     return {"start_pos": state.start, "cache": state.cache}
 
   async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
-    logits = x[:, -1, :]
     def sample_wrapper():
+      logits = x[:, -1, :]
       return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
     return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
 
@@ -112,9 +112,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
       state = self.poll_state(h, request_id)
       out = self.model.forward(h, **state)
       self.states[request_id].start += x.shape[1]
-      return out.realize()
+      return out.numpy()
     output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
-    return output_data.numpy(), inference_state
+    return output_data, inference_state
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
     def step(x, y, l):