|
@@ -92,65 +92,70 @@ class HFShardDownloader(ShardDownloader):
|
|
|
|
|
|
async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
|
|
|
if not self.current_shard or not self.current_repo_id:
|
|
|
- if DEBUG >= 2: print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
|
|
|
- return None
|
|
|
-
|
|
|
+ if DEBUG >= 2:
|
|
|
+ print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
|
|
|
+ return None
|
|
|
+
|
|
|
try:
|
|
|
- # If no snapshot directory exists, return None - no need to check remote files
|
|
|
- snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
|
|
|
- if not snapshot_dir:
|
|
|
- if DEBUG >= 2: print(f"No snapshot directory found for {self.current_repo_id}")
|
|
|
- return None
|
|
|
-
|
|
|
- # Get the weight map to know what files we need
|
|
|
- weight_map = await get_weight_map(self.current_repo_id, self.revision)
|
|
|
- if not weight_map:
|
|
|
- if DEBUG >= 2: print(f"No weight map found for {self.current_repo_id}")
|
|
|
- return None
|
|
|
-
|
|
|
- # Get all files needed for this shard
|
|
|
- patterns = get_allow_patterns(weight_map, self.current_shard)
|
|
|
-
|
|
|
- # Check download status for all relevant files
|
|
|
- status = {}
|
|
|
- total_bytes = 0
|
|
|
- downloaded_bytes = 0
|
|
|
-
|
|
|
- async with aiohttp.ClientSession() as session:
|
|
|
- file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
|
|
|
- relevant_files = list(filter_repo_objects(file_list, allow_patterns=patterns, key=lambda x: x["path"]))
|
|
|
-
|
|
|
- for file in relevant_files:
|
|
|
- file_size = file["size"]
|
|
|
- total_bytes += file_size
|
|
|
-
|
|
|
- percentage = await get_file_download_percentage(
|
|
|
- session,
|
|
|
- self.current_repo_id,
|
|
|
- self.revision,
|
|
|
- file["path"],
|
|
|
- snapshot_dir
|
|
|
- )
|
|
|
- status[file["path"]] = percentage
|
|
|
- downloaded_bytes += (file_size * (percentage / 100))
|
|
|
-
|
|
|
- # Add overall progress weighted by file size
|
|
|
- if total_bytes > 0:
|
|
|
- status["overall"] = (downloaded_bytes / total_bytes) * 100
|
|
|
- else:
|
|
|
- status["overall"] = 0
|
|
|
-
|
|
|
- if DEBUG >= 2:
|
|
|
- print(f"Download calculation for {self.current_repo_id}:")
|
|
|
- print(f"Total bytes: {total_bytes}")
|
|
|
- print(f"Downloaded bytes: {downloaded_bytes}")
|
|
|
- for file in relevant_files:
|
|
|
- print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
|
|
|
-
|
|
|
- return status
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
+ # If no snapshot directory exists, return None - no need to check remote files
|
|
|
+ snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
|
|
|
+ if not snapshot_dir:
|
|
|
if DEBUG >= 2:
|
|
|
- print(f"Error getting shard download status: {e}")
|
|
|
- traceback.print_exc()
|
|
|
+ print(f"No snapshot directory found for {self.current_repo_id}")
|
|
|
return None
|
|
|
+
|
|
|
+ # Get the weight map to know what files we need
|
|
|
+ weight_map = await get_weight_map(self.current_repo_id, self.revision)
|
|
|
+ if not weight_map:
|
|
|
+ if DEBUG >= 2:
|
|
|
+ print(f"No weight map found for {self.current_repo_id}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ # Get all files needed for this shard
|
|
|
+ patterns = get_allow_patterns(weight_map, self.current_shard)
|
|
|
+
|
|
|
+ # Check download status for all relevant files
|
|
|
+ status = {}
|
|
|
+ total_bytes = 0
|
|
|
+ downloaded_bytes = 0
|
|
|
+
|
|
|
+ async with aiohttp.ClientSession() as session:
|
|
|
+ file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
|
|
|
+ relevant_files = list(
|
|
|
+ filter_repo_objects(
|
|
|
+ file_list, allow_patterns=patterns, key=lambda x: x["path"]))
|
|
|
+
|
|
|
+ for file in relevant_files:
|
|
|
+ file_size = file["size"]
|
|
|
+ total_bytes += file_size
|
|
|
+
|
|
|
+ percentage = await get_file_download_percentage(
|
|
|
+ session,
|
|
|
+ self.current_repo_id,
|
|
|
+ self.revision,
|
|
|
+ file["path"],
|
|
|
+ snapshot_dir,
|
|
|
+ )
|
|
|
+ status[file["path"]] = percentage
|
|
|
+ downloaded_bytes += (file_size * (percentage / 100))
|
|
|
+
|
|
|
+ # Add overall progress weighted by file size
|
|
|
+ if total_bytes > 0:
|
|
|
+ status["overall"] = (downloaded_bytes / total_bytes) * 100
|
|
|
+ else:
|
|
|
+ status["overall"] = 0
|
|
|
+
|
|
|
+ if DEBUG >= 2:
|
|
|
+ print(f"Download calculation for {self.current_repo_id}:")
|
|
|
+ print(f"Total bytes: {total_bytes}")
|
|
|
+ print(f"Downloaded bytes: {downloaded_bytes}")
|
|
|
+ for file in relevant_files:
|
|
|
+ print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
|
|
|
+
|
|
|
+ return status
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ if DEBUG >= 2:
|
|
|
+ print(f"Error getting shard download status: {e}")
|
|
|
+ traceback.print_exc()
|
|
|
+ return None
|