Explorar o código

fix shard download

Alex Cheema hai 1 ano
pai
achega
7ec660bba6
Modificáronse 2 ficheiros con 3 adicións e 3 borrados
  1. 0 1
      exo/api/chatgpt_api.py
  2. 3 2
      exo/download/hf/hf_shard_download.py

+ 0 - 1
exo/api/chatgpt_api.py

@@ -93,7 +93,6 @@ async def resolve_tokenizer(model_id: str):
     return processor
   except Exception as e:
     if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
-
     if DEBUG >= 4: print(traceback.format_exc())
 
   try:

+ 3 - 2
exo/download/hf/hf_shard_download.py

@@ -15,15 +15,16 @@ class HFShardDownloader(ShardDownloader):
 
     async def ensure_shard(self, shard: Shard) -> Path:
         # If a download on this shard is already in progress, keep that one
-        for active_shard, task in self.active_downloads.values():
+        for active_shard in self.active_downloads:
             if active_shard == shard:
                 if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
-                return await task
+                return await self.active_downloads[shard]
 
         # Cancel any downloads for this model_id on a different shard
         existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
         for active_shard in existing_active_shards:
             if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
+            task = self.active_downloads[active_shard]
             task.cancel()
             try:
                 await task