|
@@ -104,40 +104,38 @@ class HFShardDownloader(ShardDownloader):
|
|
|
if not weight_map:
|
|
|
if DEBUG >= 2: print(f"No weight map found for {self.current_repo_id}")
|
|
|
return None
|
|
|
-
|
|
|
- if DEBUG >= 2: print(f"Checking download status for {self.current_repo_id}: {weight_map=}")
|
|
|
|
|
|
# Get the patterns for this shard
|
|
|
patterns = get_allow_patterns(weight_map, self.current_shard)
|
|
|
|
|
|
- # Check each required file's status
|
|
|
+ # First check which files exist locally
|
|
|
status = {}
|
|
|
- async with aiohttp.ClientSession() as session:
|
|
|
- file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
|
|
|
-
|
|
|
- for pattern in patterns:
|
|
|
- if pattern.endswith('safetensors') or pattern.endswith('mlx'):
|
|
|
- expected_size = None
|
|
|
- local_size = 0
|
|
|
-
|
|
|
- # Get expected file size from repo
|
|
|
- file_path = snapshot_dir / pattern
|
|
|
- if await aios.path.exists(file_path):
|
|
|
- local_size = await aios.path.getsize(file_path)
|
|
|
-
|
|
|
- # Find matching file in file_list
|
|
|
+ local_files = []
|
|
|
+ local_sizes = {}
|
|
|
+
|
|
|
+ for pattern in patterns:
|
|
|
+ if pattern.endswith('safetensors') or pattern.endswith('mlx'):
|
|
|
+ file_path = snapshot_dir / pattern
|
|
|
+ if await aios.path.exists(file_path):
|
|
|
+ local_size = await aios.path.getsize(file_path)
|
|
|
+ local_files.append(pattern)
|
|
|
+ local_sizes[pattern] = local_size
|
|
|
+
|
|
|
+ # Only fetch remote info if we found local files
|
|
|
+ if local_files:
|
|
|
+ async with aiohttp.ClientSession() as session:
|
|
|
+ file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
|
|
|
+
|
|
|
+ for pattern in local_files:
|
|
|
for file in file_list:
|
|
|
if file["path"].endswith(pattern):
|
|
|
- expected_size = file["size"]
|
|
|
+ status[pattern] = (local_sizes[pattern] / file["size"]) * 100
|
|
|
break
|
|
|
-
|
|
|
- if expected_size:
|
|
|
- status[pattern] = (local_size / expected_size) * 100
|
|
|
|
|
|
return status
|
|
|
|
|
|
except Exception as e:
|
|
|
- if DEBUG >= 2:
|
|
|
- print(f"Error getting shard download status: {e}")
|
|
|
- traceback.print_exc()
|
|
|
- return None
|
|
|
+ if DEBUG >= 2:
|
|
|
+ print(f"Error getting shard download status: {e}")
|
|
|
+ traceback.print_exc()
|
|
|
+ return None
|