浏览代码

fixing formatting

cadenmackenzie 8 月之前
父节点
当前提交
91276ccd4b
共有 1 个文件被更改,包括 64 次插入59 次删除
  1. 64 59
      exo/download/hf/hf_shard_download.py

+ 64 - 59
exo/download/hf/hf_shard_download.py

@@ -92,65 +92,70 @@ class HFShardDownloader(ShardDownloader):
 
   async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
     if not self.current_shard or not self.current_repo_id:
-        if DEBUG >= 2: print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
-        return None
-            
+      if DEBUG >= 2:
+        print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
+      return None
+
     try:
-        # If no snapshot directory exists, return None - no need to check remote files
-        snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
-        if not snapshot_dir:
-            if DEBUG >= 2: print(f"No snapshot directory found for {self.current_repo_id}")
-            return None
-
-        # Get the weight map to know what files we need
-        weight_map = await get_weight_map(self.current_repo_id, self.revision)
-        if not weight_map:
-            if DEBUG >= 2: print(f"No weight map found for {self.current_repo_id}")
-            return None
-        
-        # Get all files needed for this shard
-        patterns = get_allow_patterns(weight_map, self.current_shard)
-        
-        # Check download status for all relevant files
-        status = {}
-        total_bytes = 0
-        downloaded_bytes = 0
-        
-        async with aiohttp.ClientSession() as session:
-            file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
-            relevant_files = list(filter_repo_objects(file_list, allow_patterns=patterns, key=lambda x: x["path"]))
-            
-            for file in relevant_files:
-                file_size = file["size"]
-                total_bytes += file_size
-                
-                percentage = await get_file_download_percentage(
-                    session,
-                    self.current_repo_id,
-                    self.revision,
-                    file["path"],
-                    snapshot_dir
-                )
-                status[file["path"]] = percentage
-                downloaded_bytes += (file_size * (percentage / 100))
-            
-            # Add overall progress weighted by file size
-            if total_bytes > 0:
-                status["overall"] = (downloaded_bytes / total_bytes) * 100
-            else:
-                status["overall"] = 0
-
-            if DEBUG >= 2:
-                print(f"Download calculation for {self.current_repo_id}:")
-                print(f"Total bytes: {total_bytes}")
-                print(f"Downloaded bytes: {downloaded_bytes}")
-                for file in relevant_files:
-                    print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
-
-        return status
-      
-    except Exception as e:
+      # If no snapshot directory exists, return None - no need to check remote files
+      snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
+      if not snapshot_dir:
         if DEBUG >= 2:
-            print(f"Error getting shard download status: {e}")
-            traceback.print_exc()
+          print(f"No snapshot directory found for {self.current_repo_id}")
         return None
+
+      # Get the weight map to know what files we need
+      weight_map = await get_weight_map(self.current_repo_id, self.revision)
+      if not weight_map:
+        if DEBUG >= 2:
+          print(f"No weight map found for {self.current_repo_id}")
+        return None
+
+      # Get all files needed for this shard
+      patterns = get_allow_patterns(weight_map, self.current_shard)
+
+      # Check download status for all relevant files
+      status = {}
+      total_bytes = 0
+      downloaded_bytes = 0
+
+      async with aiohttp.ClientSession() as session:
+        file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
+        relevant_files = list(
+            filter_repo_objects(
+                file_list, allow_patterns=patterns, key=lambda x: x["path"]))
+
+        for file in relevant_files:
+          file_size = file["size"]
+          total_bytes += file_size
+
+          percentage = await get_file_download_percentage(
+              session,
+              self.current_repo_id,
+              self.revision,
+              file["path"],
+              snapshot_dir,
+          )
+          status[file["path"]] = percentage
+          downloaded_bytes += (file_size * (percentage / 100))
+
+        # Add overall progress weighted by file size
+        if total_bytes > 0:
+          status["overall"] = (downloaded_bytes / total_bytes) * 100
+        else:
+          status["overall"] = 0
+
+        if DEBUG >= 2:
+          print(f"Download calculation for {self.current_repo_id}:")
+          print(f"Total bytes: {total_bytes}")
+          print(f"Downloaded bytes: {downloaded_bytes}")
+          for file in relevant_files:
+            print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
+
+      return status
+
+    except Exception as e:
+      if DEBUG >= 2:
+        print(f"Error getting shard download status: {e}")
+        traceback.print_exc()
+      return None