Prechádzať zdrojové kódy

retry fetch_file_list also

Alex Cheema pred 3 mesiacmi
rodič
commit
788c49784c
1 zmenený súbor, 3 pridania a 2 odobrania
  1. 3 2
      exo/download/new_shard_download.py

+ 3 - 2
exo/download/new_shard_download.py

@@ -69,6 +69,7 @@ async def seed_models(seed_dir: Union[str, Path]):
           print(f"Error seeding model {path} to {dest_path}")
           traceback.print_exc()
 
+@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
 async def fetch_file_list(repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
@@ -151,8 +152,8 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
     start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
     downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
-    speed = downloaded_this_session / (time.time() - start_time)
-    eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
+    speed = downloaded_this_session / (time.time() - start_time) if time.time() - start_time > 0 else 0
+    eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) if speed > 0 else timedelta(seconds=0)
     file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time)
     on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
     if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")