فهرست منبع

removing checking of percentage for models that are not found locally

cadenmackenzie 9 ماه پیش
والد
کامیت
dd38924e39
1فایلهای تغییر یافته به همراه23 افزوده شده و 25 حذف شده
  1. 23 25
      exo/download/hf/hf_shard_download.py

+ 23 - 25
exo/download/hf/hf_shard_download.py

@@ -104,40 +104,38 @@ class HFShardDownloader(ShardDownloader):
         if not weight_map:
             if DEBUG >= 2: print(f"No weight map found for {self.current_repo_id}")
             return None
-
-        if DEBUG >= 2: print(f"Checking download status for {self.current_repo_id}: {weight_map=}")
         
         # Get the patterns for this shard
         patterns = get_allow_patterns(weight_map, self.current_shard)
         
-        # Check each required file's status
+        # First check which files exist locally
         status = {}
-        async with aiohttp.ClientSession() as session:
-            file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
-            
-            for pattern in patterns:
-                if pattern.endswith('safetensors') or pattern.endswith('mlx'):
-                    expected_size = None
-                    local_size = 0
-                    
-                    # Get expected file size from repo
-                    file_path = snapshot_dir / pattern
-                    if await aios.path.exists(file_path):
-                        local_size = await aios.path.getsize(file_path)
-                        
-                    # Find matching file in file_list
+        local_files = []
+        local_sizes = {}
+        
+        for pattern in patterns:
+            if pattern.endswith('safetensors') or pattern.endswith('mlx'):
+                file_path = snapshot_dir / pattern
+                if await aios.path.exists(file_path):
+                    local_size = await aios.path.getsize(file_path)
+                    local_files.append(pattern)
+                    local_sizes[pattern] = local_size
+
+        # Only fetch remote info if we found local files
+        if local_files:
+            async with aiohttp.ClientSession() as session:
+                file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
+                
+                for pattern in local_files:
                     for file in file_list:
                         if file["path"].endswith(pattern):
-                            expected_size = file["size"]
+                            status[pattern] = (local_sizes[pattern] / file["size"]) * 100
                             break
-                    
-                    if expected_size:
-                        status[pattern] = (local_size / expected_size) * 100
 
         return status
       
     except Exception as e:
-      if DEBUG >= 2:
-        print(f"Error getting shard download status: {e}")
-        traceback.print_exc()
-      return None
+        if DEBUG >= 2:
+            print(f"Error getting shard download status: {e}")
+            traceback.print_exc()
+        return None