# bench.py — benchmark exo inference engines (token generation over time).
  1. import asyncio
  2. import time
  3. import uuid
  4. import matplotlib.pyplot as plt
  5. from transformers import AutoTokenizer
  6. from exo.inference.tokenizers import resolve_tokenizer
  7. from exo.inference.inference_engine import InferenceEngine
  8. from exo.inference.shard import Shard
  9. from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
  10. from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
  11. from exo.download.hf.hf_shard_download import HFShardDownloader, RepoProgressEvent
  12. from exo.helpers import pretty_print_bytes_per_second
  13. async def run_bench(inference_engine: InferenceEngine, tokenizer, shard: Shard, num_tokens: int = 500, verbose=True):
  14. req_id = str(uuid.uuid4())
  15. start_time = time.time()
  16. total_tokens = 0
  17. tokens_over_time = []
  18. times = []
  19. prompt = tokenizer.apply_chat_template([{"role": "user", "content": "write an essay about the importance of the internet"}], tokenize=False, add_generation_prompt=True)
  20. if verbose: print(f"Prompt: {prompt}\n", flush=True)
  21. resp, inference_state, is_finished = await inference_engine.infer_prompt(req_id, shard, prompt)
  22. total_tokens += 1
  23. tokens_over_time.append(total_tokens)
  24. times.append(time.time() - start_time)
  25. while not is_finished and total_tokens < num_tokens:
  26. resp, inference_state, is_finished = await inference_engine.infer_tensor(req_id, shard, resp, inference_state)
  27. total_tokens += 1
  28. tokens_over_time.append(total_tokens)
  29. times.append(time.time() - start_time)
  30. if verbose: print(tokenizer.decode(resp), end='', flush=True)
  31. return tokens_over_time, times
async def main():
    """Run the benchmark for each configured engine, plot progress curves,
    and print a per-engine performance summary (TTFT, sustained TPS, peak TPS)."""
    shard_downloader = HFShardDownloader()

    def on_progress(shard: Shard, event: RepoProgressEvent):
        # Download progress callback: prints current speed and ETA for the shard.
        print(f"Downloading shard {shard} {pretty_print_bytes_per_second(event.overall_speed)} | {event.overall_eta}")

    shard_downloader.on_progress.register("print").on_next(on_progress)
    # (engine instance, display name, HF model id) triples; the commented-out
    # entries are alternative engine/model combinations kept for quick switching.
    engines = [
        # (TinygradDynamicShardInferenceEngine(shard_downloader), "Tinygrad", "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated"),
        # (TorchDynamicShardInferenceEngine(shard_downloader), "Torch", "unsloth/Meta-Llama-3.1-8B-Instruct"),
        # (MLXDynamicShardInferenceEngine(shard_downloader), "MLX", "mlx-community/Meta-Llama-3.1-8B-Instruct-abliterated"),
        (MLXDynamicShardInferenceEngine(shard_downloader), "MLX", "mlx-community/gemma-2-9b-it-4bit")
    ]
    plt.figure(figsize=(12, 6))
    summary = {}
    for engine, name, model_id in engines:
        # NOTE(review): layer counts are hard-coded for a 32-layer model
        # (Llama-3.1-8B style); confirm they match the active model above —
        # gemma-2-9b has a different layer count.
        shard = Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32)
        tokenizer = await resolve_tokenizer(model_id)
        # Warmup run (10 tokens) so model load/compile doesn't skew the measured run.
        await run_bench(engine, tokenizer, shard, 10)
        tokens, times = await run_bench(engine, tokenizer, shard)
        plt.plot(times, tokens, label=name)
        # Time to first token = elapsed time recorded at the first sample.
        first_token_time = times[0]
        # Calculate sustained TPS using the latter half of the data
        mid_point = len(tokens) // 2
        sustained_tps = (tokens[-1] - tokens[mid_point]) / (times[-1] - times[mid_point])
        # NOTE(review): this is the best *cumulative average* tokens/sec seen at
        # any sample (index 0 skipped to avoid early-noise/short intervals),
        # not an instantaneous per-token rate.
        peak_tps = max([tokens[i] / times[i] for i in range(1, len(tokens))])
        summary[name] = {
            "first_token_time": first_token_time,
            "sustained_tps": sustained_tps,
            "peak_tps": peak_tps
        }
    # One chart with every engine's cumulative-token curve overlaid.
    plt.xlabel("Time (seconds)")
    plt.ylabel("Tokens Generated")
    plt.title("Token Generation Over Time")
    plt.legend()
    plt.grid(True)
    plt.savefig("token_generation_comparison.png")
    plt.close()
    print("\nPerformance Summary:")
    for name, metrics in summary.items():
        print(f"\n{name}:")
        print(f" Time to First Token: {metrics['first_token_time']:.4f} seconds")
        print(f" Sustained TPS: {metrics['sustained_tps']:.2f} tokens/second")
        print(f" Peak TPS: {metrics['peak_tps']:.2f} tokens/second")
    # Pick the winning engine for each metric across the summary.
    fastest_first_token = min(summary.items(), key=lambda x: x[1]['first_token_time'])
    fastest_sustained = max(summary.items(), key=lambda x: x[1]['sustained_tps'])
    fastest_peak = max(summary.items(), key=lambda x: x[1]['peak_tps'])
    print("\nFastest Engines:")
    print(f"Fastest to First Token: {fastest_first_token[0]} ({fastest_first_token[1]['first_token_time']:.4f} seconds)")
    print(f"Fastest Sustained TPS: {fastest_sustained[0]} ({fastest_sustained[1]['sustained_tps']:.2f} tokens/second)")
    print(f"Fastest Peak TPS: {fastest_peak[0]} ({fastest_peak[1]['peak_tps']:.2f} tokens/second)")


if __name__ == "__main__":
    asyncio.run(main())