|
@@ -64,10 +64,10 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
self.shard_downloader = shard_downloader
|
|
self.shard_downloader = shard_downloader
|
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
|
|
|
|
|
async def sample(self, x: np.ndarray, temp: float = TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
    """Sample from the logits at the last index along axis 1 of ``x``.

    The sampling call is handed to ``self.executor`` through the running
    event loop, so the coroutine awaits the result instead of computing it
    inline.

    Args:
        x: Array indexed as ``x[:, -1, :]``, so it needs at least three
            axes. Presumably (batch, sequence, vocab) -- TODO confirm
            against the caller.
        temp: Temperature forwarded to ``sample_logits``; defaults to the
            module-level ``TEMPERATURE``.
        top_p: Value forwarded as the fifth positional argument of
            ``sample_logits``; defaults to ``0.0``.

    Returns:
        The realized sampling result converted with ``.numpy()``.
    """
    logits = x[:, -1, :]

    def sample_wrapper():
        # The slice is flattened into a single 1-D tensor before sampling.
        # NOTE(review): `top_p` travels in the fifth positional slot while
        # the literal 0.8 occupies the fourth -- confirm against the
        # signature of sample_logits that this is the intended parameter.
        return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize()

    out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
    return out.numpy()
|
|
|
|
|