|
|
@@ -2,10 +2,10 @@ import argparse
|
|
|
import asyncio
|
|
|
import signal
|
|
|
import json
|
|
|
-import logging
|
|
|
import time
|
|
|
import traceback
|
|
|
import uuid
|
|
|
+from typing import List
|
|
|
from exo.networking.manual.manual_discovery import ManualDiscovery
|
|
|
from exo.networking.manual.network_topology_config import NetworkTopology
|
|
|
from exo.orchestration.standard_node import StandardNode
|
|
|
@@ -51,6 +51,7 @@ parser.add_argument("--run-model", type=str, help="Specify a model to run direct
|
|
|
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
|
|
|
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
|
|
|
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
|
|
|
+parser.add_argument("--stats", action=argparse.BooleanOptionalAction, help="Enable stats tok/sec")
|
|
|
args = parser.parse_args()
|
|
|
print(f"Selected inference engine: {args.inference_engine}")
|
|
|
|
|
|
@@ -140,9 +141,19 @@ def preemptively_start_download(request_id: str, opaque_status: str):
|
|
|
print(f"Failed to preemptively start download: {e}")
|
|
|
traceback.print_exc()
|
|
|
|
|
|
-
|
|
|
node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
|
|
|
|
|
|
# Per-request throughput statistics, keyed by request id. Each value is a dict
# holding "first_token_time", "n_tokens", "tps" and, once the request has
# finished, "end_time" (times are time.perf_counter_ns() readings).
stats = {}


def collect_tps_stats(request_id: str, tokens: List[int], is_finished: bool):
    """Update the tokens-per-second statistics for one request.

    The first call for a given ``request_id`` records the current
    ``time.perf_counter_ns()`` reading as ``first_token_time``. Every call
    then stores ``len(tokens)`` as ``n_tokens`` and recomputes ``tps`` as the
    number of tokens after the first one divided by the time elapsed since
    ``first_token_time``. When ``is_finished`` is true, ``end_time`` is
    recorded and the whole ``stats`` dict is printed.

    Args:
        request_id: Key under which this request's statistics are stored.
        tokens: Token list whose length is taken as the current token count.
        is_finished: Whether this call is the last one for the request.
    """
    now = time.perf_counter_ns()
    if request_id not in stats:
        stats[request_id] = {"first_token_time": now}
    entry = stats[request_id]

    n_tokens = len(tokens)
    entry["n_tokens"] = n_tokens

    # The first token only marks the start of the measurement window, so it is
    # excluded from the rate; clamp at zero so an empty list is never negative.
    generated = max(n_tokens - 1, 0)
    elapsed_ns = now - entry["first_token_time"]
    # No time has elapsed on the first call (or on a coarse clock); report 0.0
    # instead of raising ZeroDivisionError inside the token callback.
    entry["tps"] = 1e9 * generated / elapsed_ns if elapsed_ns > 0 else 0.0

    if is_finished:
        entry["end_time"] = now
        # NOTE(review): entries are never removed from `stats`, so this prints
        # every request seen so far and the dict grows for the process lifetime.
        print(stats)
|
|
|
+
|
|
|
+node.on_token.register("collect_tps_stats").on_next(collect_tps_stats)
|
|
|
if args.prometheus_client_port:
|
|
|
from exo.stats.metrics import start_metrics_server
|
|
|
start_metrics_server(node, args.prometheus_client_port)
|