@@ -25,10 +25,11 @@ from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration.node import Node
from exo.models import model_base_shards
from exo.viz.topology_viz import TopologyViz
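+# yappi is used in run() below to profile the whole process with wall-clock timings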
+import yappi

# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
-parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
+parser.add_argument("command", nargs="?", choices=["run", "bench"], help="Command to run")
parser.add_argument("model_name", nargs="?", help="Model name to run")
parser.add_argument("--node-id", type=str, default=None, help="Node ID")
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
@@ -48,6 +49,7 @@ parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
+parser.add_argument("--bench", action=argparse.BooleanOptionalAction, help="Benchmark the model")
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
@@ -123,7 +125,7 @@ api = ChatGPTAPI(
  inference_engine.__class__.__name__,
  response_timeout=args.chatgpt_api_response_timeout,
  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
-)
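+  # Skip creating the ChatGPT API when benchmarking (api becomes None)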
+) if not args.bench else None
node.on_token.register("update_topology_viz").on_next(
  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
)
@@ -143,24 +145,7 @@ def preemptively_start_download(request_id: str, opaque_status: str):

node.on_opaque_status.register("start_download").on_next(preemptively_start_download)

-stats = {}
-def collect_tps_stats(request_id: str, tokens: List[int], is_finished: bool):
-  if not request_id in stats:
-    stats[request_id] = { "first_token_time": time.perf_counter_ns() }
-  stats[request_id]["n_tokens"] = len(tokens)
-  stats[request_id]["tps"] = 1e9 * (stats[request_id]["n_tokens"] - 1) / (time.perf_counter_ns() - stats[request_id]["first_token_time"])
-  if is_finished:
-    stats[request_id]["end_time"] = time.perf_counter_ns()
-    print(stats)
-
-node.on_token.register("collect_tps_stats").on_next(collect_tps_stats)
-if args.prometheus_client_port:
-  from exo.stats.metrics import start_metrics_server
-  start_metrics_server(node, args.prometheus_client_port)
-
last_broadcast_time = 0
-
-
def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
  global last_broadcast_time
  current_time = time.time()
@@ -212,6 +197,59 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
  finally:
    node.on_token.deregister(callback_id)

+async def bench(node: Node, inference_engine: InferenceEngine):
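+  """Run a fixed prompt through llama-3.2-3b and write timing stats to benchmark_results.json."""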
+  prompt = "write an essay about war"
+  model_name = "llama-3.2-3b"
+  shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
+  if not shard:
+    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+    return
+  tokenizer = await resolve_tokenizer(shard.model_id)
+  request_id = str(uuid.uuid4())
+  callback_id = f"cli-wait-response-{request_id}"
+  callback = node.on_token.register(callback_id)
+  if topology_viz:
+    topology_viz.update_prompt(request_id, prompt)
+  prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
+
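+  # Track per-request timing: first-token time and a running tokens-per-second estimate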
+  stats = { "start_time": time.perf_counter_ns() }
+  def collect_tps_stats(request_id: str, tokens: List[int], is_finished: bool):
+    if request_id not in stats:
+      stats[request_id] = { "first_token_time": time.perf_counter_ns() }
+    stats[request_id]["n_tokens"] = len(tokens)
+    stats[request_id]["tps"] = 1e9 * (stats[request_id]["n_tokens"] - 1) / (time.perf_counter_ns() - stats[request_id]["first_token_time"])
+    if is_finished:
+      stats[request_id]["end_time"] = time.perf_counter_ns()
+      print(stats)
+  node.on_token.register("collect_tps_stats").on_next(collect_tps_stats)
+
+  try:
+    print(f"Processing prompt: {prompt}")
+    await node.process_prompt(shard, prompt, None, request_id=request_id)
+
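+    # Block until generation for this request_id reports is_finished (or the 300s timeout)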
+    _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
+
+    print("\nGenerated response:")
+    print(tokenizer.decode(tokens))
+
+    # Write benchmark results to file (all timings are perf_counter_ns deltas, in nanoseconds)
+    # `prompt` already holds the chat-formatted prompt here, so tokenize it directly
+    # instead of applying the chat template a second time
+    prompt_tokens = tokenizer.encode(prompt)
+    benchmark_results = {
+      "total_time_ns": stats[request_id]["end_time"] - stats["start_time"],
+      "first_token_latency_ns": stats[request_id]["first_token_time"] - stats["start_time"],
+      "tokens_generated": stats[request_id]["n_tokens"],
+      "prompt_num_tokens": len(prompt_tokens),
+      "prompt_tokens_per_second": 1e9 * len(prompt_tokens) / (stats[request_id]["first_token_time"] - stats["start_time"]),
+      "generation_tokens_per_second": 1e9 * (stats[request_id]["n_tokens"] - 1) / (stats[request_id]["end_time"] - stats[request_id]["first_token_time"])
+    }
+    with open("benchmark_results.json", "w") as f:
+      json.dump(benchmark_results, f, indent=2)
+  except Exception as e:
+    print(f"Error processing prompt: {str(e)}")
+    traceback.print_exc()
+  finally:
+    node.on_token.deregister(callback_id)
+    node.on_token.deregister("collect_tps_stats")
+

async def main():
  loop = asyncio.get_running_loop()
@@ -225,7 +263,9 @@ async def main():

  await node.start(wait_for_peers=args.wait_for_peers)

-  if args.command == "run" or args.run_model:
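+  # bench mode runs the built-in benchmark instead of the run/API code paths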
+  if args.command == "bench" or args.bench:
+    await bench(node, inference_engine)
+  elif args.command == "run" or args.run_model:
    model_name = args.model_name or args.run_model
    if not model_name:
      print("Error: Model name is required when using 'run' command or --run-model")
@@ -240,7 +280,17 @@ def run():
  loop = asyncio.new_event_loop()
  asyncio.set_event_loop(loop)
  try:
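+    # Wall-clock timing so time spent awaiting I/O and peers shows up in the profile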
+    yappi.set_clock_type('wall')
+    yappi.start()
+
    loop.run_until_complete(main())
+
+    yappi.stop()
+    # Print profiling results
+    func_stats = yappi.get_func_stats()
+    func_stats.sort('ttot').print_all()
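+    # The callgrind output can be inspected with e.g. KCachegrind/QCachegrind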
+    func_stats.save('callgrind.out', type='callgrind')
+
  except KeyboardInterrupt:
    print("Received keyboard interrupt. Shutting down...")
  finally: