
proper prompt

Alex Cheema, 5 months ago
parent commit d93dccca5b
1 file changed with 13 additions and 6 deletions

test/bench.py  +13 -6
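
The change swaps the hard-coded "who are you?" prompt for one built with the model's chat template, resolved through exo's resolve_tokenizer helper. A minimal standalone sketch of the same pattern, using only the transformers API the diff already imports (assumes the tokenizer repo ships a chat template):

from transformers import AutoTokenizer

# Build a chat-formatted prompt the same way the benchmark now does.
tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit")
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "write an essay about the importance of the internet"}],
    tokenize=False,               # return a string, not token ids
    add_generation_prompt=True,   # append the assistant-turn marker
)
print(prompt)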

@@ -2,6 +2,8 @@ import asyncio
 import time
 import uuid
 import matplotlib.pyplot as plt
+from transformers import AutoTokenizer
+from exo.inference.tokenizers import resolve_tokenizer
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
@@ -9,14 +11,16 @@ from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 from exo.download.hf.hf_shard_download import HFShardDownloader, RepoProgressEvent
 from exo.helpers import pretty_print_bytes_per_second
 
-async def run_bench(inference_engine: InferenceEngine, shard: Shard, num_tokens: int = 200):
+async def run_bench(inference_engine: InferenceEngine, tokenizer, shard: Shard, num_tokens: int = 500, verbose=True):
   req_id = str(uuid.uuid4())
   start_time = time.time()
   total_tokens = 0
   tokens_over_time = []
   times = []
 
-  resp, inference_state, is_finished = await inference_engine.infer_prompt(req_id, shard, "who are you?")
+  prompt = tokenizer.apply_chat_template([{"role": "user", "content": "write an essay about the importance of the internet"}], tokenize=False, add_generation_prompt=True)
+  if verbose: print(f"Prompt: {prompt}\n", flush=True)
+  resp, inference_state, is_finished = await inference_engine.infer_prompt(req_id, shard, prompt)
   total_tokens += 1
   tokens_over_time.append(total_tokens)
   times.append(time.time() - start_time)
@@ -26,6 +30,7 @@ async def run_bench(inference_engine: InferenceEngine, shard: Shard, num_tokens:
     total_tokens += 1
     tokens_over_time.append(total_tokens)
     times.append(time.time() - start_time)
+    if verbose: print(tokenizer.decode(resp), end='', flush=True)
 
   return tokens_over_time, times
 
@@ -36,9 +41,10 @@ async def main():
   shard_downloader.on_progress.register("print").on_next(on_progress)
 
   engines = [
-    (TinygradDynamicShardInferenceEngine(shard_downloader), "Tinygrad", "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated"),
+    # (TinygradDynamicShardInferenceEngine(shard_downloader), "Tinygrad", "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated"),
     # (TorchDynamicShardInferenceEngine(shard_downloader), "Torch", "unsloth/Meta-Llama-3.1-8B-Instruct"),
-    (MLXDynamicShardInferenceEngine(shard_downloader), "MLX", "mlx-community/Meta-Llama-3.1-8B-Instruct-abliterated")
+    # (MLXDynamicShardInferenceEngine(shard_downloader), "MLX", "mlx-community/Meta-Llama-3.1-8B-Instruct-abliterated"),
+    (MLXDynamicShardInferenceEngine(shard_downloader), "MLX", "mlx-community/gemma-2-9b-it-4bit")
   ]
 
   plt.figure(figsize=(12, 6))
@@ -46,8 +52,9 @@ async def main():
 
   for engine, name, model_id in engines:
     shard = Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32)
-    await run_bench(engine, shard, 10)
-    tokens, times = await run_bench(engine, shard)
+    tokenizer = await resolve_tokenizer(model_id)
+    await run_bench(engine, tokenizer, shard, 10)
+    tokens, times = await run_bench(engine, tokenizer, shard)
 
     plt.plot(times, tokens, label=name)
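
The short run_bench(engine, tokenizer, shard, 10) call before the measured run reads as a warm-up. For a single throughput figure, the two arrays run_bench returns can be reduced directly; a small sketch, not part of the commit itself:

# Average decode throughput over the whole measured run.
tokens, times = await run_bench(engine, tokenizer, shard)
tps = tokens[-1] / times[-1]
print(f"{name}: {tps:.2f} tokens/sec")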