|
@@ -105,7 +105,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog
|
|
|
elapsed_time = time.time() - all_start_time
|
|
|
all_speed = all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
|
|
|
all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
|
|
|
- status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress"
|
|
|
+ status = "complete" if all(p.status == "complete" for p in file_progress.values()) else "not_started" if all(p.status == "not_started" for p in file_progress.values()) else "in_progress"
|
|
|
return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)
|
|
|
|
|
|
async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
|
|
@@ -147,12 +147,12 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
|
|
|
downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
|
|
|
speed = downloaded_this_session / (time.time() - start_time)
|
|
|
eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
|
|
|
- file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "in_progress", start_time)
|
|
|
+ file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time)
|
|
|
on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
|
|
|
if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
|
|
|
for file in filtered_file_list:
|
|
|
downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
|
|
|
- file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time())
|
|
|
+ file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
|
|
|
|
|
|
semaphore = asyncio.Semaphore(max_parallel_downloads)
|
|
|
async def download_with_semaphore(file):
|