Browse Source

throttle repo progress events and only send them out if something changed

Alex Cheema 5 months ago
parent
commit
3675804f4d
1 changed file with 7 additions and 5 deletions
  1. 7 5
      exo/main.py

+ 7 - 5
exo/main.py

@@ -206,14 +206,16 @@ def preemptively_load_shard(request_id: str, opaque_status: str):
       traceback.print_exc()
       traceback.print_exc()
 node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)
 node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)
 
 
# Last event broadcast per model id, stored as (time.time() at send, event sent).
last_events: dict[str, tuple[float, RepoProgressEvent]] = {}
def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
  """Broadcast a download-progress event, dropping redundant or too-frequent ones.

  Events are tracked per ``shard.model_id``. An event is skipped when:
    * its status is "not_started";
    * a "complete" event was already broadcast for this model and this one
      is "complete" as well;
    * its status equals the status of the last broadcast event for this model
      and less than 0.2 seconds have passed since that broadcast.

  Otherwise the event is recorded in ``last_events`` and sent through
  ``node.broadcast_opaque_status`` as a JSON "download_progress" message.

  NOTE: ``asyncio.create_task`` needs a running event loop, so this callback
  has to be invoked from within the loop.
  """
  if event.status == "not_started": return
  current_time = time.time()
  last_event = last_events.get(shard.model_id)
  if last_event is not None:
    last_time, last_sent = last_event
    # Completion only needs to be announced once per model.
    if last_sent.status == "complete" and event.status == "complete": return
    # Compare statuses (not the stored timestamp) so that the rate limit
    # actually applies to repeated events; a status change always goes out.
    if last_sent.status == event.status and current_time - last_time < 0.2: return
  last_events[shard.model_id] = (current_time, event)
  asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 
 async def run_model_cli(node: Node, model_name: str, prompt: str):
 async def run_model_cli(node: Node, model_name: str, prompt: str):