|
@@ -19,29 +19,24 @@ class DummyInferenceEngine(InferenceEngine):
|
|
self.eos_token_id = 0
|
|
self.eos_token_id = 0
|
|
self.latency_mean = 0.1
|
|
self.latency_mean = 0.1
|
|
self.latency_stddev = 0.02
|
|
self.latency_stddev = 0.02
|
|
|
|
+ self.num_generate_dummy_tokens = 10
|
|
self.tokenizer = DummyTokenizer()
|
|
self.tokenizer = DummyTokenizer()
|
|
|
|
|
|
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
|
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
|
return np.array(self.tokenizer.encode(prompt))
|
|
return np.array(self.tokenizer.encode(prompt))
|
|
|
|
|
|
async def sample(self, x: np.ndarray) -> np.ndarray:
|
|
async def sample(self, x: np.ndarray) -> np.ndarray:
|
|
- if random.random() < 0.1:
|
|
|
|
- return np.array([self.tokenizer.eos_token_id])
|
|
|
|
- return np.array([np.random.randint(1, self.vocab_size)])
|
|
|
|
|
|
+ print('sample', x)
|
|
|
|
+ if x[0] > self.num_generate_dummy_tokens: return np.array([self.tokenizer.eos_token_id])
|
|
|
|
+ return x
|
|
|
|
|
|
async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
|
|
async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
|
|
return self.tokenizer.decode(tokens)
|
|
return self.tokenizer.decode(tokens)
|
|
|
|
|
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
|
await self.ensure_shard(shard)
|
|
await self.ensure_shard(shard)
|
|
- sequence_length = input_data.shape[0 if self.shard.is_first_layer() else 1]
|
|
|
|
- output = np.random.random(size=(1, sequence_length, self.vocab_size if self.shard.is_last_layer() else self.hidden_size))
|
|
|
|
- return output
|
|
|
|
|
|
+ return input_data + 1 if self.shard.is_last_layer() else input_data
|
|
|
|
|
|
async def ensure_shard(self, shard: Shard):
|
|
async def ensure_shard(self, shard: Shard):
|
|
- if self.shard == shard:
|
|
|
|
- return
|
|
|
|
- # Simulate shard loading without making any API calls
|
|
|
|
- await asyncio.sleep(0.1) # Simulate a short delay
|
|
|
|
|
|
+ if self.shard == shard: return
|
|
self.shard = shard
|
|
self.shard = shard
|
|
- print(f"DummyInferenceEngine: Simulated loading of shard {shard.model_id}")
|
|
|