|
@@ -134,7 +134,7 @@ def preemptively_start_download(request_id: str, opaque_status: str):
|
|
|
if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
|
|
|
current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
|
|
|
if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
|
|
|
- asyncio.create_task(shard_downloader.ensure_shard(current_shard))
|
|
|
+ asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
|
|
|
except Exception as e:
|
|
|
if DEBUG >= 2:
|
|
|
print(f"Failed to preemptively start download: {e}")
|