@@ -11,7 +11,7 @@ import aiofiles.os as aios
 import aiohttp
 import aiofiles
 from urllib.parse import urljoin
-from typing import Callable, Union, Tuple, Dict, List, Optional, Literal
+from typing import Callable, Union, Tuple, Dict, List, Optional, Literal, AsyncIterator
 import time
 from datetime import timedelta
 import asyncio
@@ -20,7 +20,7 @@ import traceback
 import shutil
 import tempfile
 import hashlib
-from tenacity import retry, stop_after_attempt, wait_fixed
+from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_not_exception_type
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -70,8 +70,17 @@ async def seed_models(seed_dir: Union[str, Path]):
           print(f"Error seeding model {path} to {dest_path}")
           traceback.print_exc()
 
+async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> List[Dict[str, Union[str, int]]]:
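+  # Serve the repo's file listing from a per-repo/revision cache file in exo's tmp dir when present; otherwise fetch it once and persist it.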
+  cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
+  if await aios.path.exists(cache_file):
+    async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
+  file_list = await fetch_file_list(repo_id, revision)
+  await aios.makedirs(cache_file.parent, exist_ok=True)
+  async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
+  return file_list
+
 
 @retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
-async def fetch_file_list(repo_id, revision, path=""):
+async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
@@ -106,14 +115,14 @@ async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
   headers = await get_auth_headers()
   async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
     async with session.head(url, headers=headers) as r:
-      content_length = int(r.headers.get('content-length', 0))
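+      # Prefer 'x-linked-size' (the size Hugging Face reports for LFS-backed files) and fall back to 'content-length'.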
+      content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
       etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
       assert content_length > 0, f"No content length for {url}"
       assert etag is not None, f"No remote hash for {url}"
       if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
       return content_length, etag
 
-@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
+@retry(stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_not_exception_type(FileNotFoundError))
 async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
   if await aios.path.exists(target_dir/path): return target_dir/path
   await aios.makedirs((target_dir/path).parent, exist_ok=True)
@@ -124,9 +133,10 @@ async def download_file(repo_id: str, revision: str, path: str, target_dir: Path
   url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
   headers = await get_auth_headers()
   if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
-  n_read = 0
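+  # Start the byte counter at the resume offset so progress callbacks report cumulative bytes, not just this session's.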
+  n_read = resume_byte_pos or 0
   async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
     async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
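+      # A 404 is permanent: raise FileNotFoundError, which the retry policy above is configured not to retry, instead of failing 30 times.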
+      if r.status == 404: raise FileNotFoundError(f"File not found: {url}")
       assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
       async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
         while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
@@ -152,7 +162,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog
   return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)
 
 async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
-  target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--")
+  target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
   index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir)
   async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
   return index_data.get("weight_map")
@@ -166,6 +176,12 @@ async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str)
     if DEBUG >= 1: traceback.print_exc()
     return ["*"]
 
+async def get_downloaded_size(path: Path) -> int:
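+  # Count an in-flight '.partial' download toward the file's downloaded bytes; fall back to 0 when nothing is on disk.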
+  partial_path = path.with_suffix(path.suffix + ".partial")
+  if await aios.path.exists(path): return (await aios.stat(path)).st_size
+  if await aios.path.exists(partial_path): return (await aios.stat(partial_path)).st_size
+  return 0
+
 async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
   if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
   repo_id = get_repo(shard.model_id, inference_engine_classname)
@@ -180,7 +196,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
 
   all_start_time = time.time()
-  file_list = await fetch_file_list(repo_id, revision)
+  file_list = await fetch_file_list_with_cache(repo_id, revision)
   filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
   file_progress: Dict[str, RepoFileProgressEvent] = {}
   def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
@@ -192,7 +208,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
     on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
     if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
   for file in filtered_file_list:
-    downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
+    downloaded_bytes = await get_downloaded_size(target_dir/file["path"])
     file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
 
   semaphore = asyncio.Semaphore(max_parallel_downloads)
@@ -225,8 +241,9 @@ class SingletonShardDownloader(ShardDownloader):
     finally:
       if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]
 
-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
-    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
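+    # The wrapped downloader now streams results, so delegate by re-yielding rather than returning a list.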
+    async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
+      yield path, status
 
 class CachedShardDownloader(ShardDownloader):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -246,8 +263,9 @@ class CachedShardDownloader(ShardDownloader):
       self.cache[(inference_engine_name, shard)] = target_dir
       return target_dir
 
-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
-    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
+    async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
+      yield path, status
 
 class NewShardDownloader(ShardDownloader):
   def __init__(self):
@@ -261,9 +279,12 @@ class NewShardDownloader(ShardDownloader):
     target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
     return target_dir
 
-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
     if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
-    downloads = await asyncio.gather(*[download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
-    if DEBUG >= 6: print("Downloaded shards:", downloads)
-    if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
-    return [d for d in downloads if not isinstance(d, Exception)]
+    tasks = [download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])]
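+    # Yield each status as soon as its probe finishes instead of blocking on asyncio.gather for the slowest model.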
+    for task in asyncio.as_completed(tasks):
+      try:
+        path, progress = await task
+        yield (path, progress)
+      except Exception as e:
+        print("Error downloading shard:", e)