Преглед изворни кода

cache completed download paths

Alex Cheema пре 1 година
родитељ
комит
dd41026c5b
1 измењених фајлова са 7 додато и 1 уклоњено
  1. 7 1
      exo/download/hf/hf_shard_download.py

+ 7 - 1
exo/download/hf/hf_shard_download.py

@@ -11,9 +11,13 @@ from exo.helpers import AsyncCallbackSystem, DEBUG
 class HFShardDownloader(ShardDownloader):
     def __init__(self):
         self.active_downloads: Dict[Shard, asyncio.Task] = {}
+        self.completed_downloads: Dict[Shard, Path] = {}
         self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
 
     async def ensure_shard(self, shard: Shard) -> Path:
+        if shard in self.completed_downloads:
+            return self.completed_downloads[shard]
+
         # If a download on this shard is already in progress, keep that one
         for active_shard in self.active_downloads:
             if active_shard == shard:
@@ -39,7 +43,9 @@ class HFShardDownloader(ShardDownloader):
         download_task = asyncio.create_task(self._download_shard(shard))
         self.active_downloads[shard] = download_task
         try:
-            return await download_task
+            path = await download_task
+            self.completed_downloads[shard] = path
+            return path
         finally:
             # Ensure the task is removed even if an exception occurs
             print(f"Removing download task for {shard}: {shard in self.active_downloads}")