Prechádzať zdrojové kódy

make max_parallel_downloads configurable, increase download chunk size to 8MB

Alex Cheema pred 2 mesiacmi
rodič
commit
477e3a5e4c
zmenil 2 súbory, pričom vykonal 10 pridaní a 9 odobraní
  1. 8 7
      exo/download/new_shard_download.py
  2. 2 2
      exo/main.py

+ 8 - 7
exo/download/new_shard_download.py

@@ -112,7 +112,7 @@ async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str
     header = f"blob {(await aios.stat(path)).st_size}\0".encode()
     hash.update(header)
   async with aiofiles.open(path, 'rb') as f:
-    while chunk := await f.read(1024 * 1024):
+    while chunk := await f.read(8 * 1024 * 1024):
       hash.update(chunk)
   return hash.hexdigest()
 
@@ -154,7 +154,7 @@ async def _download_file(repo_id: str, revision: str, path: str, target_dir: Pat
         if r.status == 404: raise FileNotFoundError(f"File not found: {url}")
         assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
         async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
-          while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
+          while chunk := await r.content.read(8 * 1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
 
   final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
   integrity = final_hash == remote_hash
@@ -197,7 +197,7 @@ async def get_downloaded_size(path: Path) -> int:
   if await aios.path.exists(partial_path): return (await aios.stat(partial_path)).st_size
   return 0
 
-async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
+async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 8, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
   if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
   repo_id = get_repo(shard.model_id, inference_engine_classname)
   revision = "main"
@@ -238,8 +238,8 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   else:
     return target_dir, final_repo_progress
 
-def new_shard_downloader() -> ShardDownloader:
-  return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))
+def new_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
+  return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader(max_parallel_downloads)))
 
 class SingletonShardDownloader(ShardDownloader):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -283,7 +283,8 @@ class CachedShardDownloader(ShardDownloader):
       yield path, status
 
 class NewShardDownloader(ShardDownloader):
-  def __init__(self):
+  def __init__(self, max_parallel_downloads: int = 8):
+    self.max_parallel_downloads = max_parallel_downloads
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
 
   @property
@@ -291,7 +292,7 @@ class NewShardDownloader(ShardDownloader):
     return self._on_progress
 
   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
-    target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
+    target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress, max_parallel_downloads=self.max_parallel_downloads)
     return target_dir
 
   async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:

+ 2 - 2
exo/main.py

@@ -72,7 +72,7 @@ parser.add_argument("--node-port", type=int, default=None, help="Node port")
 parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
-parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
+parser.add_argument("--max-parallel-downloads", type=int, default=8, help="Max parallel downloads for model shards download")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
@@ -99,7 +99,7 @@ print_yellow_exo()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
-shard_downloader: ShardDownloader = new_shard_downloader() if args.inference_engine != "dummy" else NoopShardDownloader()
+shard_downloader: ShardDownloader = new_shard_downloader(args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 print(f"Inference engine name after selection: {inference_engine_name}")