Bläddra i källkod

adding helper funciton to check file download. also modifying download_file to use that helper

cadenmackenzie 8 månader sedan
förälder
incheckning
c61f40c64b
1 ändrade filer med 75 tillägg och 6 borttagningar
  1. 75 6
      exo/download/hf/hf_helpers.py

+ 75 - 6
exo/download/hf/hf_helpers.py

@@ -14,6 +14,7 @@ from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEv
 from exo.inference.shard import Shard
 import aiofiles
 from aiofiles import os as aios
+import traceback
 
 T = TypeVar("T")
 
@@ -131,7 +132,7 @@ async def download_file(
 ):
   base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
   url = urljoin(base_url, file_path)
-  local_path = os.path.join(save_directory, file_path)
+  local_path = Path(os.path.join(save_directory, file_path))
 
   await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
 
@@ -147,11 +148,19 @@ async def download_file(
     downloaded_size = local_file_size
     downloaded_this_session = 0
     mode = 'ab' if use_range_request else 'wb'
-    if downloaded_size == total_size:
-      if DEBUG >= 2: print(f"File already downloaded: {file_path}")
-      if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-      return
+    percentage = await get_file_download_percentage(
+        session,
+        repo_id,
+        revision,
+        file_path,
+        Path(save_directory)
+    )
+    
+    if percentage == 100:
+        if DEBUG >= 2: print(f"File already downloaded: {file_path}")
+        if progress_callback:
+            await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
+        return
 
     if response.status == 200:
       # File doesn't support range requests or we're not using them, start from beginning
@@ -412,3 +421,63 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
     shard_specific_patterns = set("*.safetensors")
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
+
+
+async def get_file_download_percentage(
+    session: aiohttp.ClientSession,
+    repo_id: str,
+    revision: str,
+    file_path: str,
+    snapshot_dir: Path
+) -> float:
+    """
+    Calculate the download percentage for a file by comparing local and remote sizes.
+    
+    Args:
+        session: Active aiohttp session
+        repo_id: The Hugging Face repository ID
+        revision: Repository revision/tag
+        file_path: Path to the file within the repo
+        snapshot_dir: Local directory where files are stored
+        
+    Returns:
+        float: Download percentage (0-100), or 0 if file doesn't exist
+    """
+    try:
+        local_path = snapshot_dir / file_path
+        if not await aios.path.exists(local_path):
+            return 0
+
+        # Get local file size first
+        local_size = await aios.path.getsize(local_path)
+        if local_size == 0:
+            return 0
+            
+        # Check remote file size
+        base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
+        url = urljoin(base_url, file_path)
+        headers = await get_auth_headers()
+        
+        async with session.head(url, headers=headers) as response:
+            if response.status != 200:
+                if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
+                return 0
+            remote_size = int(response.headers.get('Content-Length', 0))
+            
+            # If we have a local file and either:
+            # 1. Remote size is 0 (shouldn't happen but just in case)
+            # 2. Local size matches remote size
+            # 3. Local size is greater than 0 and remote size couldn't be determined
+            if remote_size == 0 or local_size == remote_size or (local_size > 0 and remote_size == 0):
+                return 100.0
+                
+            return (local_size / remote_size) * 100 if remote_size > 0 else 0
+        
+    except Exception as e:
+        if DEBUG >= 2: 
+            print(f"Error checking file download status for {file_path}: {e}")
+            traceback.print_exc()
+        # If we have a local file but can't check remote size, assume it's complete
+        if await aios.path.exists(local_path) and await aios.path.getsize(local_path) > 0:
+            return 100.0
+        return 0