Переглянути джерело

add --max-parallel-downloads flag that limits the number of concurrent downloads using asyncio.Semaphore

Alex Cheema 8 місяців тому
батько
коміт
6c1bf127b3
3 змінених файлів з 14 додано та 6 видалено
  1. 7 2
      exo/download/hf/hf_helpers.py
  2. 4 2
      exo/download/hf/hf_shard_download.py
  3. 3 2
      main.py

+ 7 - 2
exo/download/hf/hf_helpers.py

@@ -177,7 +177,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
                     await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
         if DEBUG >= 2: print(f"Downloaded: {file_path}")
 
-async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None) -> Path:
+async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, max_parallel_downloads: int = 4) -> Path:
     repo_root = get_repo_root(repo_id)
     refs_dir = repo_root / "refs"
     snapshots_dir = repo_root / "snapshots"
@@ -236,7 +236,12 @@ async def download_repo_files(repo_id: str, revision: str = "main", progress_cal
                 await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
 
         progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
-        tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
+
+        semaphore = asyncio.Semaphore(max_parallel_downloads)
+        async def download_with_semaphore(file_info):
+            async with semaphore:
+                await download_with_progress(file_info, progress_state)
+        tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
         await asyncio.gather(*tasks)
 
     return snapshot_dir

+ 4 - 2
exo/download/hf/hf_shard_download.py

@@ -9,8 +9,9 @@ from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, g
 from exo.helpers import AsyncCallbackSystem, DEBUG
 
 class HFShardDownloader(ShardDownloader):
-    def __init__(self, quick_check: bool = False):
+    def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
         self.quick_check = quick_check
+        self.max_parallel_downloads = max_parallel_downloads
         self.active_downloads: Dict[Shard, asyncio.Task] = {}
         self.completed_downloads: Dict[Shard, Path] = {}
         self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
@@ -69,7 +70,8 @@ class HFShardDownloader(ShardDownloader):
         return await download_repo_files(
             repo_id=shard.model_id,
             progress_callback=wrapped_progress_callback,
-            allow_patterns=allow_patterns
+            allow_patterns=allow_patterns,
+            max_parallel_downloads=self.max_parallel_downloads
         )
 
     @property

+ 3 - 2
main.py

@@ -20,7 +20,8 @@ parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
-parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for shard download")
+parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
+parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
 parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
@@ -37,7 +38,7 @@ print_yellow_exo()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check)
+shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")