Browse Source

add crude stats collection with --stats

Alex Cheema 1 year ago
parent
commit
7e7b1eadc5
1 changed files with 13 additions and 2 deletions
  1. 13 2
      exo/main.py

+ 13 - 2
exo/main.py

@@ -2,10 +2,10 @@ import argparse
 import asyncio
 import signal
 import json
-import logging
 import time
 import traceback
 import uuid
+from typing import List
 from exo.networking.manual.manual_discovery import ManualDiscovery
 from exo.networking.manual.network_topology_config import NetworkTopology
 from exo.orchestration.standard_node import StandardNode
@@ -51,6 +51,7 @@ parser.add_argument("--run-model", type=str, help="Specify a model to run direct
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
+parser.add_argument("--stats", action=argparse.BooleanOptionalAction, help="Enable stats tok/sec")
 args = parser.parse_args()
 print(f"Selected inference engine: {args.inference_engine}")
 
@@ -140,9 +141,19 @@ def preemptively_start_download(request_id: str, opaque_status: str):
       print(f"Failed to preemptively start download: {e}")
       traceback.print_exc()
 
-
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 
+stats = {}
+def collect_tps_stats(request_id: str, tokens: List[int], is_finished: bool):
+  if not request_id in stats:
+    stats[request_id] = { "first_token_time": time.perf_counter_ns() }
+  stats[request_id]["n_tokens"] = len(tokens)
+  stats[request_id]["tps"] = 1e9 * (stats[request_id]["n_tokens"] - 1) / (time.perf_counter_ns() - stats[request_id]["first_token_time"])
+  if is_finished:
+    stats[request_id]["end_time"] = time.perf_counter_ns()
+  print(stats)
+
+node.on_token.register("collect_tps_stats").on_next(collect_tps_stats)
 if args.prometheus_client_port:
   from exo.stats.metrics import start_metrics_server
   start_metrics_server(node, args.prometheus_client_port)