Alex Cheema 8 месяцев назад
Родитель
Commit
1331ed7679
2 измененных файлов с 9 добавлено и 5 удалено
  1. 6 2
      exo/inference/dummy_inference_engine.py
  2. 3 3
      exo/inference/tokenizers.py

+ 6 - 2
exo/inference/dummy_inference_engine.py

@@ -3,9 +3,10 @@ import numpy as np
 import random
 import string
 import asyncio
-import json
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
+from exo.inference.tokenizers import DummyTokenizer
+
 def random_string(length: int):
   return ''.join([random.choice(string.ascii_lowercase) for i in range(length)])
   
@@ -18,12 +19,15 @@ class DummyInferenceEngine(InferenceEngine):
     self.eos_token_id = 0
     self.latency_mean = 0.1
     self.latency_stddev = 0.02
+    self.tokenizer = DummyTokenizer()
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     return np.random.randint(1, self.vocab_size, size=(1, len(prompt.split())))
   
   async def sample(self, x: np.ndarray) -> np.ndarray:
-    return np.random.randint(1, self.vocab_size)
+    if random.random() < 0.1:
+      return np.array([self.tokenizer.eos_token_id])
+    return np.array([np.random.randint(1, self.vocab_size)])
 
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     return ' '.join([random_string(np.random.randint(1, 34)) for token in tokens])

+ 3 - 3
exo/inference/tokenizers.py

@@ -10,13 +10,13 @@ from exo.helpers import DEBUG
 
 class DummyTokenizer:
   def __init__(self):
-    self.eos_token_id = 0
+    self.eos_token_id = 69
 
   def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
-    return [1, 2, 3]
+    return "dummy_tokenized_prompt"
 
   def decode(self, tokens):
-    return "dummy"
+    return "dummy" * len(tokens)
 
 
 async def resolve_tokenizer(model_id: str):