|
@@ -1,16 +1,9 @@
|
|
|
from typing import Optional, Tuple, TYPE_CHECKING
|
|
|
import numpy as np
|
|
|
-import random
|
|
|
-import string
|
|
|
-import asyncio
|
|
|
from exo.inference.inference_engine import InferenceEngine
|
|
|
from exo.inference.shard import Shard
|
|
|
from exo.inference.tokenizers import DummyTokenizer
|
|
|
|
|
|
-def random_string(length: int):
|
|
|
- return ''.join([random.choice(string.ascii_lowercase) for i in range(length)])
|
|
|
-
|
|
|
-
|
|
|
class DummyInferenceEngine(InferenceEngine):
|
|
|
def __init__(self):
|
|
|
self.shard = None
|
|
@@ -26,7 +19,6 @@ class DummyInferenceEngine(InferenceEngine):
|
|
|
return np.array(self.tokenizer.encode(prompt))
|
|
|
|
|
|
async def sample(self, x: np.ndarray) -> np.ndarray:
|
|
|
- print('sample', x)
|
|
|
if x[0] > self.num_generate_dummy_tokens: return np.array([self.tokenizer.eos_token_id])
|
|
|
return x
|
|
|
|