@@ -22,7 +22,7 @@ class DummyInferenceEngine(InferenceEngine):
     self.tokenizer = DummyTokenizer()
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
-    return np.random.randint(1, self.vocab_size, size=(1, len(prompt.split())))
+    return np.array(self.tokenizer.encode(prompt))
 
   async def sample(self, x: np.ndarray) -> np.ndarray:
     if random.random() < 0.1:
@@ -30,7 +30,7 @@ class DummyInferenceEngine(InferenceEngine):
       return np.array([np.random.randint(1, self.vocab_size)])
 
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
-    return self.tokenizer.decode(tokens)
+    return self.tokenizer.decode(tokens)
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
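
The rewritten encode and decode paths assume a DummyTokenizer that exposes encode(str) -> list of token ids and decode(tokens) -> str. As a minimal sketch only (the class below is an assumption for illustration, not the tokenizer actually shipped in this repo):

import numpy as np

# Hypothetical stand-in, assumed for illustration -- the real DummyTokenizer may differ.
class DummyTokenizer:
  def __init__(self, vocab_size: int = 1000):
    self.vocab_size = vocab_size
    self.eos_token_id = 0

  def encode(self, text: str) -> list[int]:
    # Deterministically map each whitespace-separated word to a fake token id.
    return [hash(word) % self.vocab_size for word in text.split()]

  def decode(self, tokens) -> str:
    # Render one placeholder "word" per token id.
    return " ".join(f"tok{int(t)}" for t in np.asarray(tokens).flatten())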