|
@@ -177,7 +177,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
|
|
|
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
|
|
|
if DEBUG >= 2: print(f"Downloaded: {file_path}")
|
|
|
|
|
|
-async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None) -> Path:
|
|
|
+async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, max_parallel_downloads: int = 4) -> Path:
|
|
|
repo_root = get_repo_root(repo_id)
|
|
|
refs_dir = repo_root / "refs"
|
|
|
snapshots_dir = repo_root / "snapshots"
|
|
@@ -236,7 +236,12 @@ async def download_repo_files(repo_id: str, revision: str = "main", progress_cal
|
|
|
await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
|
|
|
|
|
|
progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
|
|
|
- tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
|
|
|
+
|
|
|
+ semaphore = asyncio.Semaphore(max_parallel_downloads)
|
|
|
+ async def download_with_semaphore(file_info):
|
|
|
+ async with semaphore:
|
|
|
+ await download_with_progress(file_info, progress_state)
|
|
|
+ tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
|
|
|
await asyncio.gather(*tasks)
|
|
|
|
|
|
return snapshot_dir
|