|
@@ -69,6 +69,7 @@ async def seed_models(seed_dir: Union[str, Path]):
|
|
|
print(f"Error seeding model {path} to {dest_path}")
|
|
|
traceback.print_exc()
|
|
|
|
|
|
+@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
|
|
|
async def fetch_file_list(repo_id, revision, path=""):
|
|
|
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
|
|
|
url = f"{api_url}/{path}" if path else api_url
|
|
@@ -151,8 +152,8 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
|
|
|
def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
|
|
|
start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
|
|
|
downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
|
|
|
- speed = downloaded_this_session / (time.time() - start_time)
|
|
|
- eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
|
|
|
+ speed = downloaded_this_session / (time.time() - start_time) if time.time() - start_time > 0 else 0
|
|
|
+ eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) if speed > 0 else timedelta(seconds=0)
|
|
|
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time)
|
|
|
on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
|
|
|
if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
|