@@ -75,7 +75,7 @@ async def fetch_file_list(repo_id, revision, path=""):
   url = f"{api_url}/{path}" if path else api_url
 
   headers = await get_auth_headers()
-  async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30, connect=10, sock_read=1800, sock_connect=60)) as session:
+  async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30, connect=10, sock_read=30, sock_connect=10)) as session:
     async with session.get(url, headers=headers) as response:
       if response.status == 200:
         data = await response.json()
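
The timeout change is the first fix: `fetch_file_list` only retrieves a small JSON listing, and the old 1800-second `sock_read` and 60-second `sock_connect` values could never bite under the 30-second `total` cap anyway; the new values make the per-socket limits consistent with it. For reference, this is what the four `aiohttp.ClientTimeout` knobs govern (a standalone illustration of the new settings, not code from this patch; values are in seconds):

```python
import aiohttp

timeout = aiohttp.ClientTimeout(
  total=30,         # cap on the whole request, from connect to last byte read
  connect=10,       # cap on acquiring a connection (pool wait plus establishment)
  sock_read=30,     # cap on each individual read from the socket
  sock_connect=10,  # cap on the TCP connect when a new connection is opened
)
```
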
@@ -84,7 +84,7 @@ async def fetch_file_list(repo_id, revision, path=""):
           if item["type"] == "file":
             files.append({"path": item["path"], "size": item["size"]})
           elif item["type"] == "directory":
-            subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
+            subfiles = await fetch_file_list(repo_id, revision, item["path"])
             files.extend(subfiles)
         return files
       else:
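
The second hunk fixes the recursive call: it passed `session` as the first positional argument even though the signature shown in the hunk header is `fetch_file_list(repo_id, revision, path="")`, so the recursion would raise a `TypeError` on the first repo containing a subdirectory. After the fix, each level of the recursion simply opens its own short-lived session. In isolation the shape is roughly this (a minimal sketch; `API_BASE` and `list_files` are illustrative stand-ins, not names from the codebase):

```python
import aiohttp

API_BASE = "https://huggingface.co/api/models"  # assumed endpoint, illustration only

async def list_files(repo_id: str, revision: str, path: str = "") -> list:
  """Recursively list files under a repo path, mirroring the fixed signature."""
  url = f"{API_BASE}/{repo_id}/tree/{revision}"
  if path: url = f"{url}/{path}"
  timeout = aiohttp.ClientTimeout(total=30, connect=10, sock_read=30, sock_connect=10)
  # Each call owns its session, so the recursion needs no session argument.
  async with aiohttp.ClientSession(timeout=timeout) as session:
    async with session.get(url) as response:
      response.raise_for_status()
      data = await response.json()
  files = []
  for item in data:
    if item["type"] == "file":
      files.append({"path": item["path"], "size": item["size"]})
    elif item["type"] == "directory":
      # Same signature top to bottom; this is the call the patch fixes.
      files.extend(await list_files(repo_id, revision, item["path"]))
  return files
```
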
@@ -169,17 +169,17 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
     downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
     file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
 
-    semaphore = asyncio.Semaphore(max_parallel_downloads)
-    async def download_with_semaphore(file):
-      async with semaphore:
-        await download_file(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
-    if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
-    final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
-    on_progress.trigger_all(shard, final_repo_progress)
-    if gguf := next((f for f in filtered_file_list if f["path"].endswith(".gguf")), None):
-      return target_dir/gguf["path"], final_repo_progress
-    else:
-      return target_dir, final_repo_progress
+  semaphore = asyncio.Semaphore(max_parallel_downloads)
+  async def download_with_semaphore(file):
+    async with semaphore:
+      await download_file(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
+  if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
+  final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
+  on_progress.trigger_all(shard, final_repo_progress)
+  if gguf := next((f for f in filtered_file_list if f["path"].endswith(".gguf")), None):
+    return target_dir/gguf["path"], final_repo_progress
+  else:
+    return target_dir, final_repo_progress
 
 def new_shard_downloader() -> ShardDownloader:
   return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))
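
The third hunk reads as a pure indentation fix: the removed and re-added lines are textually identical, and the block moves from the body of the per-file loop that seeds `file_progress` out to the function body, so the function no longer returns from inside the first loop iteration before the remaining files have been seeded. Isolated from the surrounding code, the bounded-concurrency pattern it restores looks like this (a minimal runnable sketch; `MAX_PARALLEL` and `fetch_one` are illustrative placeholders, not names from the repo):

```python
import asyncio

MAX_PARALLEL = 4  # stand-in for max_parallel_downloads

async def fetch_one(path: str) -> None:
  # Placeholder for the real download_file call.
  await asyncio.sleep(0.1)
  print(f"downloaded {path}")

async def download_all(paths: list) -> None:
  semaphore = asyncio.Semaphore(MAX_PARALLEL)

  async def bounded(path: str) -> None:
    async with semaphore:  # at most MAX_PARALLEL bodies run at once
      await fetch_one(path)

  # One task per file; the semaphore, not gather, limits concurrency.
  await asyncio.gather(*(bounded(p) for p in paths))

asyncio.run(download_all([f"shard_{i}.bin" for i in range(10)]))
```
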