|
@@ -11,9 +11,13 @@ from exo.helpers import AsyncCallbackSystem, DEBUG
|
|
|
class HFShardDownloader(ShardDownloader):
|
|
|
def __init__(self):
|
|
|
self.active_downloads: Dict[Shard, asyncio.Task] = {}
|
|
|
+ self.completed_downloads: Dict[Shard, Path] = {}
|
|
|
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
|
|
|
|
|
|
async def ensure_shard(self, shard: Shard) -> Path:
|
|
|
+ if shard in self.completed_downloads:
|
|
|
+ return self.completed_downloads[shard]
|
|
|
+
|
|
|
# If a download on this shard is already in progress, keep that one
|
|
|
for active_shard in self.active_downloads:
|
|
|
if active_shard == shard:
|
|
@@ -39,7 +43,9 @@ class HFShardDownloader(ShardDownloader):
|
|
|
download_task = asyncio.create_task(self._download_shard(shard))
|
|
|
self.active_downloads[shard] = download_task
|
|
|
try:
|
|
|
- return await download_task
|
|
|
+ path = await download_task
|
|
|
+ self.completed_downloads[shard] = path
|
|
|
+ return path
|
|
|
finally:
|
|
|
# Ensure the task is removed even if an exception occurs
|
|
|
print(f"Removing download task for {shard}: {shard in self.active_downloads}")
|