@@ -2,6 +2,8 @@ import asyncio
 import time
 import uuid
 import matplotlib.pyplot as plt
+from transformers import AutoTokenizer
+from exo.inference.tokenizers import resolve_tokenizer
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
@@ -9,14 +11,16 @@ from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 from exo.download.hf.hf_shard_download import HFShardDownloader, RepoProgressEvent
 from exo.helpers import pretty_print_bytes_per_second
 
-async def run_bench(inference_engine: InferenceEngine, shard: Shard, num_tokens: int = 200):
+async def run_bench(inference_engine: InferenceEngine, tokenizer, shard: Shard, num_tokens: int = 500, verbose=True):
   req_id = str(uuid.uuid4())
   start_time = time.time()
   total_tokens = 0
   tokens_over_time = []
   times = []
 
-  resp, inference_state, is_finished = await inference_engine.infer_prompt(req_id, shard, "who are you?")
+  prompt = tokenizer.apply_chat_template([{"role": "user", "content": "write an essay about the importance of the internet"}], tokenize=False, add_generation_prompt=True)
+  if verbose: print(f"Prompt: {prompt}\n", flush=True)
+  resp, inference_state, is_finished = await inference_engine.infer_prompt(req_id, shard, prompt)
   total_tokens += 1
   tokens_over_time.append(total_tokens)
   times.append(time.time() - start_time)
@@ -26,6 +30,7 @@ async def run_bench(inference_engine: InferenceEngine, shard: Shard, num_tokens:
     total_tokens += 1
     tokens_over_time.append(total_tokens)
     times.append(time.time() - start_time)
+    if verbose: print(tokenizer.decode(resp), end='', flush=True)
 
   return tokens_over_time, times
 
@@ -36,9 +41,10 @@ async def main():
   shard_downloader.on_progress.register("print").on_next(on_progress)
 
   engines = [
-    (TinygradDynamicShardInferenceEngine(shard_downloader), "Tinygrad", "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated"),
+    # (TinygradDynamicShardInferenceEngine(shard_downloader), "Tinygrad", "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated"),
     # (TorchDynamicShardInferenceEngine(shard_downloader), "Torch", "unsloth/Meta-Llama-3.1-8B-Instruct"),
-    (MLXDynamicShardInferenceEngine(shard_downloader), "MLX", "mlx-community/Meta-Llama-3.1-8B-Instruct-abliterated")
+    # (MLXDynamicShardInferenceEngine(shard_downloader), "MLX", "mlx-community/Meta-Llama-3.1-8B-Instruct-abliterated"),
+    (MLXDynamicShardInferenceEngine(shard_downloader), "MLX", "mlx-community/gemma-2-9b-it-4bit")
   ]
 
   plt.figure(figsize=(12, 6))
@@ -46,8 +52,9 @@ async def main():
 
   for engine, name, model_id in engines:
     shard = Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32)
-    await run_bench(engine, shard, 10)
-    tokens, times = await run_bench(engine, shard)
+    tokenizer = await resolve_tokenizer(model_id)
+    await run_bench(engine, tokenizer, shard, 10)
+    tokens, times = await run_bench(engine, tokenizer, shard)
 
     plt.plot(times, tokens, label=name)
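
For reference, a minimal sketch (not part of the patch) of what the new prompt construction produces. It assumes the tokenizer for the Gemma model id above loads directly via the transformers AutoTokenizer the patch imports; inside exo the same tokenizer comes from resolve_tokenizer, and the rendered markup shown is approximate.

from transformers import AutoTokenizer

# Illustrative only: load the tokenizer for the model id used in the patch.
tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit")

# Wrap the user message in the model's chat markup and append the generation
# prompt so the model continues as the assistant.
prompt = tokenizer.apply_chat_template(
  [{"role": "user", "content": "write an essay about the importance of the internet"}],
  tokenize=False,
  add_generation_prompt=True,
)
print(prompt)
# Roughly, for a Gemma-2 chat template:
# <bos><start_of_turn>user
# write an essay about the importance of the internet<end_of_turn>
# <start_of_turn>model

The benchmark keeps its two calls per engine: a short 10-token pass, presumably as a warmup, then the measured run with the new 500-token default that feeds the tokens-over-time plot.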