Browse Source

fix dummy inference

Alex Cheema 7 months ago
parent
commit
f601a83070
2 changed files with 7 additions and 2 deletions
  1. exo/inference/dummy_inference_engine.py — 2 additions, 2 deletions
  2. exo/inference/tokenizers.py — 5 additions, 0 deletions

+ 2 - 2
exo/inference/dummy_inference_engine.py

@@ -22,7 +22,7 @@ class DummyInferenceEngine(InferenceEngine):
     self.tokenizer = DummyTokenizer()
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
-    return np.random.randint(1, self.vocab_size, size=(1, len(prompt.split())))
+    return np.array(self.tokenizer.encode(prompt))
   
   async def sample(self, x: np.ndarray) -> np.ndarray:
     if random.random() < 0.1:
@@ -30,7 +30,7 @@ class DummyInferenceEngine(InferenceEngine):
     return np.array([np.random.randint(1, self.vocab_size)])
 
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
-    return ' '.join([random_string(np.random.randint(1, 34)) for token in tokens])
+    return self.tokenizer.decode(tokens)
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)

+ 5 - 0
exo/inference/tokenizers.py

@@ -4,6 +4,7 @@ from os import PathLike
 from pathlib import Path
 from typing import Union
 from transformers import AutoTokenizer, AutoProcessor
+import numpy as np
 from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
 
@@ -11,10 +12,14 @@ from exo.helpers import DEBUG
 class DummyTokenizer:
   def __init__(self):
     self.eos_token_id = 69
+    self.vocab_size = 1000
 
   def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
     return "dummy_tokenized_prompt"
 
+  def encode(self, text):
+    return np.random.randint(1, self.vocab_size, size=(1, len(text.split())))
+
   def decode(self, tokens):
     return "dummy" * len(tokens)