@@ -20,7 +20,6 @@ import traceback
 import shutil
 import tempfile
 import hashlib
-from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_not_exception_type
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -74,13 +73,21 @@ async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> Li
   cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
   if await aios.path.exists(cache_file):
     async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
-  file_list = await fetch_file_list(repo_id, revision)
+  file_list = await fetch_file_list_with_retry(repo_id, revision)
   await aios.makedirs(cache_file.parent, exist_ok=True)
   async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
   return file_list
 
-@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
-async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
+async def fetch_file_list_with_retry(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
+  n_attempts = 30
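+  # retry transient failures with exponential backoff: 0.1s, 0.2s, 0.4s, ... capped at 8s per wait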
+  for attempt in range(n_attempts):
+    try: return await _fetch_file_list(repo_id, revision, path)
+    except Exception as e:
+      if attempt == n_attempts - 1: raise e
+      await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
+
+async def _fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
 
@@ -94,7 +101,7 @@ async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "")
           if item["type"] == "file":
             files.append({"path": item["path"], "size": item["size"]})
           elif item["type"] == "directory":
-            subfiles = await fetch_file_list(repo_id, revision, item["path"])
+            subfiles = await _fetch_file_list(repo_id, revision, item["path"])
             files.extend(subfiles)
         return files
       else:
@@ -122,8 +129,16 @@ async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
   if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
   return content_length, etag
 
-@retry(stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_not_exception_type(FileNotFoundError))
-async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
+async def download_file_with_retry(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
+  n_attempts = 30
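+  # a FileNotFoundError (404) won't heal on retry, so re-raise it immediately; back off on everything else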
+  for attempt in range(n_attempts):
+    try: return await _download_file(repo_id, revision, path, target_dir, on_progress)
+    except Exception as e:
+      if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1: raise e
+      await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
+
+async def _download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
   if await aios.path.exists(target_dir/path): return target_dir/path
   await aios.makedirs((target_dir/path).parent, exist_ok=True)
   length, remote_hash = await file_meta(repo_id, revision, path)
@@ -163,7 +178,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog
 
 async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
   target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
-  index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir)
+  index_file = await download_file_with_retry(repo_id, revision, "model.safetensors.index.json", target_dir)
   async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
   return index_data.get("weight_map")
 
@@ -214,7 +229,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   semaphore = asyncio.Semaphore(max_parallel_downloads)
   async def download_with_semaphore(file):
     async with semaphore:
-      await download_file(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
+      await download_file_with_retry(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
   if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
   final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
   on_progress.trigger_all(shard, final_repo_progress)
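Note: both wrappers implement the same policy (up to 30 attempts, exponential backoff starting at 0.1s, capped at 8s per wait), replacing the fixed 1s waits of the removed tenacity decorators. A minimal sketch of that shared policy factored into one helper; retry_with_backoff and its parameters are hypothetical, not part of this patch:

import asyncio
from typing import Awaitable, Callable, Tuple, Type, TypeVar

T = TypeVar("T")

# hypothetical helper, not part of this patch
async def retry_with_backoff(
  fn: Callable[[], Awaitable[T]],
  n_attempts: int = 30,
  base_delay: float = 0.1,
  max_delay: float = 8.0,
  no_retry: Tuple[Type[BaseException], ...] = (),
) -> T:
  # same schedule as the loops above: base_delay doubles each attempt, capped at max_delay
  for attempt in range(n_attempts):
    try:
      return await fn()
    except Exception as e:
      if isinstance(e, no_retry) or attempt == n_attempts - 1: raise
      await asyncio.sleep(min(max_delay, base_delay * (2 ** attempt)))
  raise AssertionError("unreachable for n_attempts >= 1")

With it, download_file_with_retry's body would reduce to a single call, e.g.
  return await retry_with_backoff(lambda: _download_file(repo_id, revision, path, target_dir, on_progress), no_retry=(FileNotFoundError,))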