|
@@ -206,14 +206,16 @@ def preemptively_load_shard(request_id: str, opaque_status: str):
|
|
traceback.print_exc()
|
|
traceback.print_exc()
|
|
node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)
|
|
node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)
|
|
|
|
|
|
-last_broadcast_time = 0
|
|
|
|
|
|
+last_events: dict[str, tuple[float, RepoProgressEvent]] = {}
|
|
def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
    """Broadcast a download-progress event to peers, rate-limited per model.

    Events are keyed by ``shard.model_id`` in the module-level ``last_events``
    mapping, which stores ``(timestamp, event)`` for the most recent event
    that was actually broadcast for that model.

    An event is dropped when:
      * its status is ``"not_started"``;
      * it is a repeated ``"complete"`` for a model whose last broadcast
        event was already ``"complete"``;
      * it has the same status as the last broadcast event for that model
        and less than 0.2 seconds have passed since that broadcast.

    Otherwise the event is recorded and a task is scheduled to broadcast it
    as a ``"download_progress"`` opaque status.

    Args:
        shard: Shard whose ``model_id`` selects the throttle bucket.
        event: Progress event to broadcast; ``status`` and ``to_dict()``
            are read from it.
    """
    if event.status == "not_started":
        return

    current_time = time.time()
    last_event = last_events.get(shard.model_id)
    if last_event is not None:
        last_time, last_progress = last_event
        # Compare statuses, not the stored timestamp: the previous code
        # compared last_event[0] (a float) with event.status (a str), which
        # is never equal, so the time-based throttle never took effect.
        same_status = last_progress.status == event.status
        if same_status and event.status == "complete":
            return
        if same_status and current_time - last_time < 0.2:
            return

    # A status change (e.g. in_progress -> complete) always goes through,
    # so the final state is never swallowed by the rate limit.
    last_events[shard.model_id] = (current_time, event)
    payload = json.dumps({
        "type": "download_progress",
        "node_id": node.id,
        "progress": event.to_dict(),
    })
    asyncio.create_task(node.broadcast_opaque_status("", payload))
|
|
shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
|
|
shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
|
|
|
|
|
|
async def run_model_cli(node: Node, model_name: str, prompt: str):
|
|
async def run_model_cli(node: Node, model_name: str, prompt: str):
|