Explorar o código

modify get_shard_download_status to use helper function

cadenmackenzie hai 8 meses
pai
achega
dec79ac747
Modificáronse 1 ficheiros con 36 adicións e 22 borrados
  1. 36 22
      exo/download/hf/hf_shard_download.py

+ 36 - 22
exo/download/hf/hf_shard_download.py

@@ -7,7 +7,9 @@ from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.hf.hf_helpers import (
     download_repo_files, RepoProgressEvent, get_weight_map, 
-    get_allow_patterns, get_repo_root, fetch_file_list, get_local_snapshot_dir
+    get_allow_patterns, get_repo_root, fetch_file_list, 
+    get_local_snapshot_dir, get_file_download_percentage,
+    filter_repo_objects
 )
 from exo.helpers import AsyncCallbackSystem, DEBUG
 from exo.models import model_cards, get_repo
@@ -105,32 +107,44 @@ class HFShardDownloader(ShardDownloader):
             if DEBUG >= 2: print(f"No weight map found for {self.current_repo_id}")
             return None
         
-        # Get the patterns for this shard
+        # Get all files needed for this shard
         patterns = get_allow_patterns(weight_map, self.current_shard)
         
-        # First check which files exist locally
+        # Check download status for all relevant files
         status = {}
-        local_files = []
-        local_sizes = {}
+        total_bytes = 0
+        downloaded_bytes = 0
         
-        for pattern in patterns:
-            if pattern.endswith('safetensors') or pattern.endswith('mlx'):
-                file_path = snapshot_dir / pattern
-                if await aios.path.exists(file_path):
-                    local_size = await aios.path.getsize(file_path)
-                    local_files.append(pattern)
-                    local_sizes[pattern] = local_size
-
-        # Only fetch remote info if we found local files
-        if local_files:
-            async with aiohttp.ClientSession() as session:
-                file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
+        async with aiohttp.ClientSession() as session:
+            file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
+            relevant_files = list(filter_repo_objects(file_list, allow_patterns=patterns, key=lambda x: x["path"]))
+            
+            for file in relevant_files:
+                file_size = file["size"]
+                total_bytes += file_size
                 
-                for pattern in local_files:
-                    for file in file_list:
-                        if file["path"].endswith(pattern):
-                            status[pattern] = (local_sizes[pattern] / file["size"]) * 100
-                            break
+                percentage = await get_file_download_percentage(
+                    session,
+                    self.current_repo_id,
+                    self.revision,
+                    file["path"],
+                    snapshot_dir
+                )
+                status[file["path"]] = percentage
+                downloaded_bytes += (file_size * (percentage / 100))
+            
+            # Add overall progress weighted by file size
+            if total_bytes > 0:
+                status["overall"] = (downloaded_bytes / total_bytes) * 100
+            else:
+                status["overall"] = 0
+
+            if DEBUG >= 2:
+                print(f"Download calculation for {self.current_repo_id}:")
+                print(f"Total bytes: {total_bytes}")
+                print(f"Downloaded bytes: {downloaded_bytes}")
+                for file in relevant_files:
+                    print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
 
         return status