Browse Source

fix dummy generate so it doesn't have any randomness

Alex Cheema 7 months ago
parent
commit
2502ed20d2
2 changed files with 7 additions and 12 deletions
  1. 6 11
      exo/inference/dummy_inference_engine.py
  2. 1 1
      exo/inference/tokenizers.py

+ 6 - 11
exo/inference/dummy_inference_engine.py

@@ -19,29 +19,24 @@ class DummyInferenceEngine(InferenceEngine):
     self.eos_token_id = 0
     self.latency_mean = 0.1
     self.latency_stddev = 0.02
+    self.num_generate_dummy_tokens = 10
     self.tokenizer = DummyTokenizer()
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     return np.array(self.tokenizer.encode(prompt))
   
   async def sample(self, x: np.ndarray) -> np.ndarray:
-    if random.random() < 0.1:
-      return np.array([self.tokenizer.eos_token_id])
-    return np.array([np.random.randint(1, self.vocab_size)])
+    print('sample', x)
+    if x[0] > self.num_generate_dummy_tokens: return np.array([self.tokenizer.eos_token_id])
+    return x
 
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     return self.tokenizer.decode(tokens)
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    sequence_length = input_data.shape[0 if self.shard.is_first_layer() else 1]
-    output = np.random.random(size=(1, sequence_length, self.vocab_size if self.shard.is_last_layer() else self.hidden_size))
-    return output
+    return input_data + 1 if self.shard.is_last_layer() else input_data
 
   async def ensure_shard(self, shard: Shard):
-    if self.shard == shard:
-      return
-    # Simulate shard loading without making any API calls
-    await asyncio.sleep(0.1)  # Simulate a short delay
+    if self.shard == shard: return
     self.shard = shard
-    print(f"DummyInferenceEngine: Simulated loading of shard {shard.model_id}")

+ 1 - 1
exo/inference/tokenizers.py

@@ -18,7 +18,7 @@ class DummyTokenizer:
     return "dummy_tokenized_prompt"
 
   def encode(self, text):
-    return np.random.randint(1, self.vocab_size, size=(1, len(text.split())))
+    return np.array([1])
 
   def decode(self, tokens):
     return "dummy" * len(tokens)