|
@@ -3,9 +3,10 @@ import numpy as np
|
|
|
import random
|
|
|
import string
|
|
|
import asyncio
|
|
|
-import json
|
|
|
from exo.inference.inference_engine import InferenceEngine
|
|
|
from exo.inference.shard import Shard
|
|
|
+from exo.inference.tokenizers import DummyTokenizer
|
|
|
+
|
|
|
def random_string(length: int):
    """Return a string of ``length`` random lowercase ASCII letters.

    One ``random.choice`` draw is made per character, in order, so the
    result is reproducible after ``random.seed``. A ``length`` of zero or
    less yields the empty string because ``range`` produces no items.
    """
    alphabet = string.ascii_lowercase
    return ''.join(random.choice(alphabet) for _ in range(length))
|
|
|
|
|
@@ -18,15 +19,18 @@ class DummyInferenceEngine(InferenceEngine):
|
|
|
self.eos_token_id = 0
|
|
|
self.latency_mean = 0.1
|
|
|
self.latency_stddev = 0.02
|
|
|
+ self.tokenizer = DummyTokenizer()
|
|
|
|
|
|
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
    """Tokenize ``prompt`` and return the token ids as a NumPy array.

    The ids come from ``self.tokenizer.encode``; ``shard`` is not read by
    this implementation.
    """
    # NOTE(review): the array's shape/dtype follow whatever the tokenizer
    # returns -- confirm callers do not expect a leading batch axis.
    token_ids = self.tokenizer.encode(prompt)
    return np.array(token_ids)
|
|
|
|
|
|
async def sample(self, x: np.ndarray) -> np.ndarray:
    """Return a one-element array holding a randomly chosen token id.

    With probability 0.1 the id is ``self.tokenizer.eos_token_id``;
    otherwise it is drawn from ``np.random.randint(1, self.vocab_size)``
    (upper bound exclusive). ``x`` is not read by this implementation.
    """
    emit_eos = random.random() < 0.1
    # The conditional expression only draws from NumPy's RNG on the
    # non-EOS path, so no extra random state is consumed when EOS is chosen.
    token_id = (
        self.tokenizer.eos_token_id
        if emit_eos
        else np.random.randint(1, self.vocab_size)
    )
    return np.array([token_id])
|
|
|
|
|
|
async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
    """Convert ``tokens`` back to text via ``self.tokenizer.decode``.

    ``tokens`` is passed to the tokenizer unchanged and ``shard`` is not
    read by this implementation.
    """
    text = self.tokenizer.decode(tokens)
    return text
|
|
|
|
|
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
|
|
await self.ensure_shard(shard)
|