|
@@ -14,6 +14,7 @@ from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEv
|
|
|
from exo.inference.shard import Shard
|
|
|
import aiofiles
|
|
|
from aiofiles import os as aios
|
|
|
+import traceback
|
|
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
@@ -131,7 +132,7 @@ async def download_file(
|
|
|
):
|
|
|
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
|
|
|
url = urljoin(base_url, file_path)
|
|
|
- local_path = os.path.join(save_directory, file_path)
|
|
|
+ local_path = Path(os.path.join(save_directory, file_path))
|
|
|
|
|
|
await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
|
|
|
|
|
@@ -147,11 +148,19 @@ async def download_file(
|
|
|
downloaded_size = local_file_size
|
|
|
downloaded_this_session = 0
|
|
|
mode = 'ab' if use_range_request else 'wb'
|
|
|
- if downloaded_size == total_size:
|
|
|
- if DEBUG >= 2: print(f"File already downloaded: {file_path}")
|
|
|
- if progress_callback:
|
|
|
- await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
|
|
- return
|
|
|
+ percentage = await get_file_download_percentage(
|
|
|
+ session,
|
|
|
+ repo_id,
|
|
|
+ revision,
|
|
|
+ file_path,
|
|
|
+ Path(save_directory)
|
|
|
+ )
|
|
|
+
|
|
|
+ if percentage == 100:
|
|
|
+ if DEBUG >= 2: print(f"File already downloaded: {file_path}")
|
|
|
+ if progress_callback:
|
|
|
+ await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
|
|
|
+ return
|
|
|
|
|
|
if response.status == 200:
|
|
|
# File doesn't support range requests or we're not using them, start from beginning
|
|
@@ -412,3 +421,63 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
|
|
|
shard_specific_patterns = set("*.safetensors")
|
|
|
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
|
|
|
return list(default_patterns | shard_specific_patterns)
|
|
|
+
|
|
|
+
|
|
|
+async def get_file_download_percentage(
|
|
|
+ session: aiohttp.ClientSession,
|
|
|
+ repo_id: str,
|
|
|
+ revision: str,
|
|
|
+ file_path: str,
|
|
|
+ snapshot_dir: Path
|
|
|
+) -> float:
|
|
|
+ """
|
|
|
+ Calculate the download percentage for a file by comparing local and remote sizes.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ session: Active aiohttp session
|
|
|
+ repo_id: The Hugging Face repository ID
|
|
|
+ revision: Repository revision/tag
|
|
|
+ file_path: Path to the file within the repo
|
|
|
+ snapshot_dir: Local directory where files are stored
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ float: Download percentage (0-100), or 0 if file doesn't exist
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ local_path = snapshot_dir / file_path
|
|
|
+ if not await aios.path.exists(local_path):
|
|
|
+ return 0
|
|
|
+
|
|
|
+ # Get local file size first
|
|
|
+ local_size = await aios.path.getsize(local_path)
|
|
|
+ if local_size == 0:
|
|
|
+ return 0
|
|
|
+
|
|
|
+ # Check remote file size
|
|
|
+ base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
|
|
|
+ url = urljoin(base_url, file_path)
|
|
|
+ headers = await get_auth_headers()
|
|
|
+
|
|
|
+ async with session.head(url, headers=headers) as response:
|
|
|
+ if response.status != 200:
|
|
|
+ if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
|
|
|
+ return 0
|
|
|
+ remote_size = int(response.headers.get('Content-Length', 0))
|
|
|
+
|
|
|
+ # If we have a local file and either:
|
|
|
+ # 1. Remote size is 0 (shouldn't happen but just in case)
|
|
|
+ # 2. Local size matches remote size
|
|
|
+ # 3. Local size is greater than 0 and remote size couldn't be determined
|
|
|
+ if remote_size == 0 or local_size == remote_size or (local_size > 0 and remote_size == 0):
|
|
|
+ return 100.0
|
|
|
+
|
|
|
+ return (local_size / remote_size) * 100 if remote_size > 0 else 0
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ if DEBUG >= 2:
|
|
|
+ print(f"Error checking file download status for {file_path}: {e}")
|
|
|
+ traceback.print_exc()
|
|
|
+ # If we have a local file but can't check remote size, assume it's complete
|
|
|
+ if await aios.path.exists(local_path) and await aios.path.getsize(local_path) > 0:
|
|
|
+ return 100.0
|
|
|
+ return 0
|