|
@@ -15,15 +15,16 @@ class HFShardDownloader(ShardDownloader):
|
|
|
|
|
|
async def ensure_shard(self, shard: Shard) -> Path:
|
|
async def ensure_shard(self, shard: Shard) -> Path:
|
|
# If a download on this shard is already in progress, keep that one
|
|
# If a download on this shard is already in progress, keep that one
|
|
- for active_shard, task in self.active_downloads.values():
|
|
|
|
|
|
+ for active_shard in self.active_downloads:
|
|
if active_shard == shard:
|
|
if active_shard == shard:
|
|
if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
|
|
if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
|
|
- return await task
|
|
|
|
|
|
+ return await self.active_downloads[shard]
|
|
|
|
|
|
# Cancel any downloads for this model_id on a different shard
|
|
# Cancel any downloads for this model_id on a different shard
|
|
existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
|
|
existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
|
|
for active_shard in existing_active_shards:
|
|
for active_shard in existing_active_shards:
|
|
if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
|
|
if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
|
|
|
|
+ task = self.active_downloads[active_shard]
|
|
task.cancel()
|
|
task.cancel()
|
|
try:
|
|
try:
|
|
await task
|
|
await task
|