Explorar o código

modifying helper fucntion checking size to follow redirect for .safetensor files to properly check the size with GET request

cadenmackenzie hai 8 meses
pai
achega
4c6fda7cab
Modificáronse 1 ficheiros con 27 adicións e 27 borrados
  1. 27 27
      exo/download/hf/hf_helpers.py

+ 27 - 27
exo/download/hf/hf_helpers.py

@@ -432,16 +432,6 @@ async def get_file_download_percentage(
 ) -> float:
     """
     Calculate the download percentage for a file by comparing local and remote sizes.
-    
-    Args:
-        session: Active aiohttp session
-        repo_id: The Hugging Face repository ID
-        revision: Repository revision/tag
-        file_path: Path to the file within the repo
-        snapshot_dir: Local directory where files are stored
-        
-    Returns:
-        float: Download percentage (0-100), or 0 if file doesn't exist
     """
     try:
         local_path = snapshot_dir / file_path
@@ -453,31 +443,41 @@ async def get_file_download_percentage(
         if local_size == 0:
             return 0
             
-        # Check remote file size
+        # Check remote size
         base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
         url = urljoin(base_url, file_path)
         headers = await get_auth_headers()
         
-        async with session.head(url, headers=headers) as response:
-            if response.status != 200:
-                if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
-                return 0
-            remote_size = int(response.headers.get('Content-Length', 0))
+        # For safetensors files, we need to follow redirects and use GET instead of HEAD
+        if file_path.endswith('.safetensors'):
+            async with session.get(url, headers=headers, allow_redirects=True) as response:
+                if response.status != 200:
+                    if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
+                    return 0
+                    
+                remote_size = int(response.headers.get('Content-Length', 0))
+                # Don't download the actual file, just get the headers
+                await response.release()
+        else:
+            async with session.head(url, headers=headers) as response:
+                if response.status != 200:
+                    if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
+                    return 0
+                remote_size = int(response.headers.get('Content-Length', 0))
+        
+        if remote_size == 0:
+            if DEBUG >= 2: print(f"Remote size is 0 for {file_path}")
+            return 0
+            
+        # Only return 100% if sizes match exactly
+        if local_size == remote_size:
+            return 100.0
             
-            # If we have a local file and either:
-            # 1. Remote size is 0 (shouldn't happen but just in case)
-            # 2. Local size matches remote size
-            # 3. Local size is greater than 0 and remote size couldn't be determined
-            if remote_size == 0 or local_size == remote_size or (local_size > 0 and remote_size == 0):
-                return 100.0
-                
-            return (local_size / remote_size) * 100 if remote_size > 0 else 0
+        # Calculate percentage based on sizes
+        return (local_size / remote_size) * 100 if remote_size > 0 else 0
         
     except Exception as e:
         if DEBUG >= 2: 
             print(f"Error checking file download status for {file_path}: {e}")
             traceback.print_exc()
-        # If we have a local file but can't check remote size, assume it's complete
-        if await aios.path.exists(local_path) and await aios.path.getsize(local_path) > 0:
-            return 100.0
         return 0