@@ -112,7 +112,7 @@ async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str
     header = f"blob {(await aios.stat(path)).st_size}\0".encode()
     hash.update(header)
   async with aiofiles.open(path, 'rb') as f:
-    while chunk := await f.read(1024 * 1024):
+    while chunk := await f.read(8 * 1024 * 1024):
       hash.update(chunk)
   return hash.hexdigest()

@@ -154,7 +154,7 @@ async def _download_file(repo_id: str, revision: str, path: str, target_dir: Pat
       if r.status == 404: raise FileNotFoundError(f"File not found: {url}")
       assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
       async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
-        while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
+        while chunk := await r.content.read(8 * 1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)

   final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
   integrity = final_hash == remote_hash
@@ -197,7 +197,7 @@ async def get_downloaded_size(path: Path) -> int:
   if await aios.path.exists(partial_path): return (await aios.stat(partial_path)).st_size
   return 0

-async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
+async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 8, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
   if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
   repo_id = get_repo(shard.model_id, inference_engine_classname)
   revision = "main"
@@ -238,8 +238,8 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   else:
     return target_dir, final_repo_progress

-def new_shard_downloader() -> ShardDownloader:
-  return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))
+def new_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
+  return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader(max_parallel_downloads)))

 class SingletonShardDownloader(ShardDownloader):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -283,7 +283,8 @@ class CachedShardDownloader(ShardDownloader):
       yield path, status

 class NewShardDownloader(ShardDownloader):
-  def __init__(self):
+  def __init__(self, max_parallel_downloads: int = 8):
+    self.max_parallel_downloads = max_parallel_downloads
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

   @property
@@ -291,7 +292,7 @@ class NewShardDownloader(ShardDownloader):
     return self._on_progress

   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
-    target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
+    target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress, max_parallel_downloads=self.max_parallel_downloads)
     return target_dir

   async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
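
For context, a minimal usage sketch of the new knob (not part of this diff). The new_shard_downloader() and ensure_shard() signatures come from the patch itself; the Shard constructor fields, the engine classname, and the module paths are assumptions for illustration:

import asyncio
from exo.inference.shard import Shard  # assumed module path
from exo.download.new_shard_download import new_shard_downloader  # assumed module path

async def main():
  # Build the downloader stack with a custom cap on concurrent per-file
  # downloads; omitting the argument uses the new default of 8.
  downloader = new_shard_downloader(max_parallel_downloads=4)
  # Hypothetical shard; field names and values are illustrative only.
  shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16)
  path = await downloader.ensure_shard(shard, "MLXDynamicShardInferenceEngine")  # classname assumed
  print(path)

asyncio.run(main())

Design note: threading the limit through new_shard_downloader() rather than a module-level constant lets each downloader instance tune its own parallelism. SingletonShardDownloader and CachedShardDownloader need no changes beyond construction because they only delegate; NewShardDownloader is the layer that actually invokes download_shard.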