|
@@ -432,16 +432,6 @@ async def get_file_download_percentage(
|
|
|
) -> float:
|
|
|
"""
|
|
|
Calculate the download percentage for a file by comparing local and remote sizes.
|
|
|
-
|
|
|
- Args:
|
|
|
- session: Active aiohttp session
|
|
|
- repo_id: The Hugging Face repository ID
|
|
|
- revision: Repository revision/tag
|
|
|
- file_path: Path to the file within the repo
|
|
|
- snapshot_dir: Local directory where files are stored
|
|
|
-
|
|
|
- Returns:
|
|
|
- float: Download percentage (0-100), or 0 if file doesn't exist
|
|
|
"""
|
|
|
try:
|
|
|
local_path = snapshot_dir / file_path
|
|
@@ -453,31 +443,41 @@ async def get_file_download_percentage(
|
|
|
if local_size == 0:
|
|
|
return 0
|
|
|
|
|
|
- # Check remote file size
|
|
|
+ # Check remote size
|
|
|
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
|
|
|
url = urljoin(base_url, file_path)
|
|
|
headers = await get_auth_headers()
|
|
|
|
|
|
- async with session.head(url, headers=headers) as response:
|
|
|
- if response.status != 200:
|
|
|
- if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
|
|
|
- return 0
|
|
|
- remote_size = int(response.headers.get('Content-Length', 0))
|
|
|
+ # For safetensors files, we need to follow redirects and use GET instead of HEAD
|
|
|
+ if file_path.endswith('.safetensors'):
|
|
|
+ async with session.get(url, headers=headers, allow_redirects=True) as response:
|
|
|
+ if response.status != 200:
|
|
|
+ if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
|
|
|
+ return 0
|
|
|
+
|
|
|
+ remote_size = int(response.headers.get('Content-Length', 0))
|
|
|
+ # Don't download the actual file, just get the headers
|
|
|
+ await response.release()
|
|
|
+ else:
|
|
|
+ async with session.head(url, headers=headers) as response:
|
|
|
+ if response.status != 200:
|
|
|
+ if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
|
|
|
+ return 0
|
|
|
+ remote_size = int(response.headers.get('Content-Length', 0))
|
|
|
+
|
|
|
+ if remote_size == 0:
|
|
|
+ if DEBUG >= 2: print(f"Remote size is 0 for {file_path}")
|
|
|
+ return 0
|
|
|
+
|
|
|
+ # Only return 100% if sizes match exactly
|
|
|
+ if local_size == remote_size:
|
|
|
+ return 100.0
|
|
|
|
|
|
- # If we have a local file and either:
|
|
|
- # 1. Remote size is 0 (shouldn't happen but just in case)
|
|
|
- # 2. Local size matches remote size
|
|
|
- # 3. Local size is greater than 0 and remote size couldn't be determined
|
|
|
- if remote_size == 0 or local_size == remote_size or (local_size > 0 and remote_size == 0):
|
|
|
- return 100.0
|
|
|
-
|
|
|
- return (local_size / remote_size) * 100 if remote_size > 0 else 0
|
|
|
+ # Calculate percentage based on sizes
|
|
|
+ return (local_size / remote_size) * 100 if remote_size > 0 else 0
|
|
|
|
|
|
except Exception as e:
|
|
|
if DEBUG >= 2:
|
|
|
print(f"Error checking file download status for {file_path}: {e}")
|
|
|
traceback.print_exc()
|
|
|
- # If we have a local file but can't check remote size, assume it's complete
|
|
|
- if await aios.path.exists(local_path) and await aios.path.getsize(local_path) > 0:
|
|
|
- return 100.0
|
|
|
return 0
|