@@ -7,7 +7,9 @@ from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.hf.hf_helpers import (
   download_repo_files, RepoProgressEvent, get_weight_map,
-  get_allow_patterns, get_repo_root, fetch_file_list, get_local_snapshot_dir
+  get_allow_patterns, get_repo_root, fetch_file_list,
+  get_local_snapshot_dir, get_file_download_percentage,
+  filter_repo_objects
 )
 from exo.helpers import AsyncCallbackSystem, DEBUG
 from exo.models import model_cards, get_repo
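
Note (not part of the patch): the two helpers added to this import, get_file_download_percentage and filter_repo_objects, are used in the status hunk further down. As an illustration only, filter_repo_objects is assumed here to behave like the huggingface_hub utility of the same name: it yields the items whose key matches any of the fnmatch-style allow patterns.

  from exo.download.hf.hf_helpers import filter_repo_objects

  file_list = [
    {"path": "config.json", "size": 1_024},
    {"path": "model-00001-of-00002.safetensors", "size": 2_000_000_000},
    {"path": "model-00002-of-00002.safetensors", "size": 2_000_000_000},
  ]
  patterns = ["*.json", "model-00001-of-00002.safetensors"]

  # Expected to keep config.json and the first shard file; the second shard file matches no pattern.
  relevant = list(filter_repo_objects(file_list, allow_patterns=patterns, key=lambda x: x["path"]))
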
@@ -94,6 +96,7 @@ class HFShardDownloader(ShardDownloader):
       return None

     try:
+      # If no snapshot directory exists, return None - no need to check remote files
       snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
       if not snapshot_dir:
         if DEBUG >= 2: print(f"No snapshot directory found for {self.current_repo_id}")
@@ -105,32 +108,44 @@ class HFShardDownloader(ShardDownloader):
         if DEBUG >= 2: print(f"No weight map found for {self.current_repo_id}")
         return None

-      # Get the patterns for this shard
+      # Get all files needed for this shard
       patterns = get_allow_patterns(weight_map, self.current_shard)

-      # First check which files exist locally
+      # Check download status for all relevant files
       status = {}
-      local_files = []
-      local_sizes = {}
+      total_bytes = 0
+      downloaded_bytes = 0

-      for pattern in patterns:
-        if pattern.endswith('safetensors') or pattern.endswith('mlx'):
-          file_path = snapshot_dir / pattern
-          if await aios.path.exists(file_path):
-            local_size = await aios.path.getsize(file_path)
-            local_files.append(pattern)
-            local_sizes[pattern] = local_size
-
-      # Only fetch remote info if we found local files
-      if local_files:
-        async with aiohttp.ClientSession() as session:
-          file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
+      async with aiohttp.ClientSession() as session:
+        file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
+        relevant_files = list(filter_repo_objects(file_list, allow_patterns=patterns, key=lambda x: x["path"]))
+
+        for file in relevant_files:
+          file_size = file["size"]
+          total_bytes += file_size

-          for pattern in local_files:
-            for file in file_list:
-              if file["path"].endswith(pattern):
-                status[pattern] = (local_sizes[pattern] / file["size"]) * 100
-                break
+          percentage = await get_file_download_percentage(
+            session,
+            self.current_repo_id,
+            self.revision,
+            file["path"],
+            snapshot_dir
+          )
+          status[file["path"]] = percentage
+          downloaded_bytes += (file_size * (percentage / 100))
+
+        # Add overall progress weighted by file size
+        if total_bytes > 0:
+          status["overall"] = (downloaded_bytes / total_bytes) * 100
+        else:
+          status["overall"] = 0
+
+        if DEBUG >= 2:
+          print(f"Download calculation for {self.current_repo_id}:")
+          print(f"Total bytes: {total_bytes}")
+          print(f"Downloaded bytes: {downloaded_bytes}")
+          for file in relevant_files:
+            print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")

       return status
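
For reference, the aggregation introduced in the last hunk weights each file's completion percentage by its size. A minimal standalone sketch of the same arithmetic (hypothetical helper, not part of the patch):

  from typing import Dict, List

  def weighted_overall_progress(files: List[Dict], per_file_percentage: Dict[str, float]) -> float:
    # files: [{"path": ..., "size": ...}] as returned by fetch_file_list;
    # per_file_percentage: path -> percentage in [0, 100], as produced by get_file_download_percentage.
    total_bytes = sum(f["size"] for f in files)
    if total_bytes == 0:
      return 0.0
    downloaded_bytes = sum(f["size"] * (per_file_percentage.get(f["path"], 0.0) / 100) for f in files)
    return (downloaded_bytes / total_bytes) * 100

  # Example: a 2 GB file at 50% plus a 1 GB file at 100% -> 2 GB of 3 GB downloaded, ~66.7% overall.

Weighting by size means a single large, partially downloaded safetensors file dominates the reported "overall" value, rather than every file counting equally.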