Browse Source

Merge pull request #4 from cadenmackenzie/hf_helperRefactor

Hf helper refactor
Caden MacKenzie 8 months ago
parent
commit
fad0591f29
3 changed files with 103 additions and 35 deletions
  1. 4 8
      exo/api/chatgpt_api.py
  2. 62 5
      exo/download/hf/hf_helpers.py
  3. 37 22
      exo/download/hf/hf_shard_download.py

+ 4 - 8
exo/api/chatgpt_api.py

@@ -231,14 +231,10 @@ class ChatGPTAPI:
                         if DEBUG >= 2:
                             print(f"Download status for {model_name}: {status}")
                         
-                        # Calculate overall percentage if we have status
-                        download_percentage = None
-                        if status:
-                            percentages = list(status.values())
-                            if percentages:
-                                download_percentage = sum(percentages) / len(percentages)
-                                if DEBUG >= 2:
-                                    print(f"Calculated download percentage for {model_name}: {download_percentage}")
+                        # Get overall percentage from status
+                        download_percentage = status.get("overall") if status else None
+                        if DEBUG >= 2 and download_percentage is not None:
+                            print(f"Overall download percentage for {model_name}: {download_percentage}")
                         
                         model_pool[model_name] = {
                             "name": pretty,

+ 62 - 5
exo/download/hf/hf_helpers.py

@@ -147,11 +147,19 @@ async def download_file(
     downloaded_size = local_file_size
     downloaded_this_session = 0
     mode = 'ab' if use_range_request else 'wb'
-    if downloaded_size == total_size:
-      if DEBUG >= 2: print(f"File already downloaded: {file_path}")
-      if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-      return
+    percentage = await get_file_download_percentage(
+        session,
+        repo_id,
+        revision,
+        file_path,
+        Path(save_directory)
+    )
+    
+    if percentage == 100:
+        if DEBUG >= 2: print(f"File already downloaded: {file_path}")
+        if progress_callback:
+            await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
+        return
 
     if response.status == 200:
       # File doesn't support range requests or we're not using them, start from beginning
@@ -412,3 +420,52 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
     shard_specific_patterns = set("*.safetensors")
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
+
+
+async def get_file_download_percentage(
+    session: aiohttp.ClientSession,
+    repo_id: str,
+    revision: str,
+    file_path: str,
+    snapshot_dir: Path
+) -> float:
+    """
+    Calculate the download percentage for a file by comparing local and remote sizes.
+    """
+    try:
+        local_path = snapshot_dir / file_path
+        if not await aios.path.exists(local_path):
+            return 0
+
+        # Get local file size first
+        local_size = await aios.path.getsize(local_path)
+        if local_size == 0:
+            return 0
+            
+        # Check remote size
+        base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
+        url = urljoin(base_url, file_path)
+        headers = await get_auth_headers()
+        
+        # Use HEAD request with redirect following for all files
+        async with session.head(url, headers=headers, allow_redirects=True) as response:
+            if response.status != 200:
+                if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
+                return 0
+                
+            remote_size = int(response.headers.get('Content-Length', 0))
+            
+            if remote_size == 0:
+                if DEBUG >= 2: print(f"Remote size is 0 for {file_path}")
+                return 0
+                
+            # Only return 100% if sizes match exactly
+            if local_size == remote_size:
+                return 100.0
+                
+            # Calculate percentage based on sizes
+            return (local_size / remote_size) * 100 if remote_size > 0 else 0
+        
+    except Exception as e:
+        if DEBUG >= 2: print(f"Error checking file download status for {file_path}: {e}")
+        return 0

+ 37 - 22
exo/download/hf/hf_shard_download.py

@@ -7,7 +7,9 @@ from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.hf.hf_helpers import (
     download_repo_files, RepoProgressEvent, get_weight_map, 
-    get_allow_patterns, get_repo_root, fetch_file_list, get_local_snapshot_dir
+    get_allow_patterns, get_repo_root, fetch_file_list, 
+    get_local_snapshot_dir, get_file_download_percentage,
+    filter_repo_objects
 )
 from exo.helpers import AsyncCallbackSystem, DEBUG
 from exo.models import model_cards, get_repo
@@ -94,6 +96,7 @@ class HFShardDownloader(ShardDownloader):
         return None
             
     try:
+        # If no snapshot directory exists, return None - no need to check remote files
         snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
         if not snapshot_dir:
             if DEBUG >= 2: print(f"No snapshot directory found for {self.current_repo_id}")
@@ -105,32 +108,44 @@ class HFShardDownloader(ShardDownloader):
             if DEBUG >= 2: print(f"No weight map found for {self.current_repo_id}")
             return None
         
-        # Get the patterns for this shard
+        # Get all files needed for this shard
         patterns = get_allow_patterns(weight_map, self.current_shard)
         
-        # First check which files exist locally
+        # Check download status for all relevant files
         status = {}
-        local_files = []
-        local_sizes = {}
+        total_bytes = 0
+        downloaded_bytes = 0
         
-        for pattern in patterns:
-            if pattern.endswith('safetensors') or pattern.endswith('mlx'):
-                file_path = snapshot_dir / pattern
-                if await aios.path.exists(file_path):
-                    local_size = await aios.path.getsize(file_path)
-                    local_files.append(pattern)
-                    local_sizes[pattern] = local_size
-
-        # Only fetch remote info if we found local files
-        if local_files:
-            async with aiohttp.ClientSession() as session:
-                file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
+        async with aiohttp.ClientSession() as session:
+            file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
+            relevant_files = list(filter_repo_objects(file_list, allow_patterns=patterns, key=lambda x: x["path"]))
+            
+            for file in relevant_files:
+                file_size = file["size"]
+                total_bytes += file_size
                 
-                for pattern in local_files:
-                    for file in file_list:
-                        if file["path"].endswith(pattern):
-                            status[pattern] = (local_sizes[pattern] / file["size"]) * 100
-                            break
+                percentage = await get_file_download_percentage(
+                    session,
+                    self.current_repo_id,
+                    self.revision,
+                    file["path"],
+                    snapshot_dir
+                )
+                status[file["path"]] = percentage
+                downloaded_bytes += (file_size * (percentage / 100))
+            
+            # Add overall progress weighted by file size
+            if total_bytes > 0:
+                status["overall"] = (downloaded_bytes / total_bytes) * 100
+            else:
+                status["overall"] = 0
+
+            if DEBUG >= 2:
+                print(f"Download calculation for {self.current_repo_id}:")
+                print(f"Total bytes: {total_bytes}")
+                print(f"Downloaded bytes: {downloaded_bytes}")
+                for file in relevant_files:
+                    print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
 
         return status