|
@@ -18,7 +18,7 @@ class DummyInferenceEngine(InferenceEngine):
|
|
|
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
|
|
return np.array(self.tokenizer.encode(prompt))
|
|
|
|
|
|
- async def sample(self, x: np.ndarray) -> np.ndarray:
|
|
|
+ async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
|
|
|
if x[0] > self.num_generate_dummy_tokens: return np.array([self.tokenizer.eos_token_id])
|
|
|
return x
|
|
|
|