remove tenacity dependency, implement simple retry logic instead

Alex Cheema, 4 months ago
commit 5157d80a46
2 changed files with 22 additions and 10 deletions:
  exo/download/new_shard_download.py  +22 -9
  setup.py                            +0 -1
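
The new loops replace tenacity's fixed one-second wait (wait_fixed(1)) with exponential backoff capped at 8 seconds. A quick look at the delay schedule the min(8, 0.1 * (2 ** attempt)) formula below produces (illustration only, not part of the commit):

  # 30 attempts means at most 29 sleeps; the final attempt re-raises instead of sleeping.
  delays = [min(8, 0.1 * (2 ** attempt)) for attempt in range(29)]
  print(delays[:8])   # [0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 6.4, 8]
  print(sum(delays))  # ~188.7s worst-case total sleep, vs a flat 29s under wait_fixed(1)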

exo/download/new_shard_download.py  +22 -9

@@ -20,7 +20,6 @@ import traceback
 import shutil
 import tempfile
 import hashlib
-from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_not_exception_type
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -74,13 +73,20 @@ async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> Li
   cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
   if await aios.path.exists(cache_file):
     async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
-  file_list = await fetch_file_list(repo_id, revision)
+  file_list = await fetch_file_list_with_retry(repo_id, revision)
   await aios.makedirs(cache_file.parent, exist_ok=True)
   async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
   return file_list
 
-@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
-async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
+async def fetch_file_list_with_retry(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
+  n_attempts = 30
+  for attempt in range(n_attempts):
+    try: return await _fetch_file_list(repo_id, revision, path)
+    except Exception as e:
+      if attempt == n_attempts - 1: raise e
+      await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
+
+async def _fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
 
@@ -94,7 +100,7 @@ async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "")
           if item["type"] == "file":
             files.append({"path": item["path"], "size": item["size"]})
           elif item["type"] == "directory":
-            subfiles = await fetch_file_list(repo_id, revision, item["path"])
+            subfiles = await _fetch_file_list(repo_id, revision, item["path"])
             files.extend(subfiles)
         return files
       else:
@@ -122,8 +128,15 @@ async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
       if  (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
       return content_length, etag
 
-@retry(stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_not_exception_type(FileNotFoundError))
-async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
+async def download_file_with_retry(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
+  n_attempts = 30
+  for attempt in range(n_attempts):
+    try: return await _download_file(repo_id, revision, path, target_dir, on_progress)
+    except Exception as e:
+      if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1: raise e
+      await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
+
+async def _download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
   if await aios.path.exists(target_dir/path): return target_dir/path
   await aios.makedirs((target_dir/path).parent, exist_ok=True)
   length, remote_hash = await file_meta(repo_id, revision, path)
@@ -163,7 +176,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog
 
 async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
   target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
-  index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir)
+  index_file = await download_file_with_retry(repo_id, revision, "model.safetensors.index.json", target_dir)
   async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
   return index_data.get("weight_map")
 
@@ -214,7 +227,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   semaphore = asyncio.Semaphore(max_parallel_downloads)
   async def download_with_semaphore(file):
     async with semaphore:
-      await download_file(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
+      await download_file_with_retry(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
   if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
   final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
   on_progress.trigger_all(shard, final_repo_progress)
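
Both wrappers repeat the same loop; the isinstance(e, FileNotFoundError) check in download_file_with_retry preserves the old retry_if_not_exception_type(FileNotFoundError) behavior of failing fast on missing files. If more call sites appear, the loop could be factored into a shared helper; a minimal sketch (hypothetical with_retry, not part of this commit):

  import asyncio
  from typing import Awaitable, Callable, Tuple, Type, TypeVar

  T = TypeVar("T")

  async def with_retry(fn: Callable[[], Awaitable[T]], n_attempts: int = 30, no_retry: Tuple[Type[BaseException], ...] = ()) -> T:
    for attempt in range(n_attempts):
      try: return await fn()
      except Exception as e:
        # re-raise immediately for excluded exception types or on the last attempt
        if isinstance(e, no_retry) or attempt == n_attempts - 1: raise e
        await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
    raise AssertionError("unreachable: the last attempt always returns or raises")

  # download_file_with_retry would then reduce to:
  #   return await with_retry(lambda: _download_file(repo_id, revision, path, target_dir, on_progress), no_retry=(FileNotFoundError,))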

setup.py  +0 -1

@@ -24,7 +24,6 @@ install_requires = [
   "requests==2.32.3",
   "rich==13.7.1",
   "scapy==2.6.1",
-  "tenacity==9.0.0",
   "tqdm==4.66.4",
   "transformers==4.46.3",
   "uuid==1.30",