Преглед изворни кода

adding redirect for all requests

cadenmackenzie пре 8 месеци
родитељ
комит
3ac868729e
1 измењених фајлова са 17 додато и 26 уклоњено
  1. 17 26
      exo/download/hf/hf_helpers.py

+ 17 - 26
exo/download/hf/hf_helpers.py

@@ -448,33 +448,24 @@ async def get_file_download_percentage(
         url = urljoin(base_url, file_path)
         headers = await get_auth_headers()
         
-        # For safetensors files, we need to follow redirects and use GET instead of HEAD
-        if file_path.endswith('.safetensors'):
-            async with session.get(url, headers=headers, allow_redirects=True) as response:
-                if response.status != 200:
-                    if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
-                    return 0
-                    
-                remote_size = int(response.headers.get('Content-Length', 0))
-                # Don't download the actual file, just get the headers
-                await response.release()
-        else:
-            async with session.head(url, headers=headers) as response:
-                if response.status != 200:
-                    if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
-                    return 0
-                remote_size = int(response.headers.get('Content-Length', 0))
-        
-        if remote_size == 0:
-            if DEBUG >= 2: print(f"Remote size is 0 for {file_path}")
-            return 0
-            
-        # Only return 100% if sizes match exactly
-        if local_size == remote_size:
-            return 100.0
+        # Use HEAD request with redirect following for all files
+        async with session.head(url, headers=headers, allow_redirects=True) as response:
+            if response.status != 200:
+                if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
+                return 0
+                
+            remote_size = int(response.headers.get('Content-Length', 0))
             
-        # Calculate percentage based on sizes
-        return (local_size / remote_size) * 100 if remote_size > 0 else 0
+            if remote_size == 0:
+                if DEBUG >= 2: print(f"Remote size is 0 for {file_path}")
+                return 0
+                
+            # Only return 100% if sizes match exactly
+            if local_size == remote_size:
+                return 100.0
+                
+            # Calculate percentage based on sizes
+            return (local_size / remote_size) * 100 if remote_size > 0 else 0
         
     except Exception as e:
         if DEBUG >= 2: