@@ -0,0 +1,180 @@
+from exo.inference.shard import Shard
+from exo.models import get_repo, get_supported_models, build_base_shard
+from pathlib import Path
+from exo.download.hf.hf_helpers import get_hf_endpoint, get_auth_headers, filter_repo_objects, get_allow_patterns
+from exo.download.shard_download import ShardDownloader
+from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent
+from exo.helpers import AsyncCallbackSystem, DEBUG
+import os
+import aiofiles.os as aios
+import aiohttp
+import aiofiles
+from urllib.parse import urljoin
+from typing import Optional, Callable, Tuple, Dict
+import time
+from datetime import timedelta
+import asyncio
+import json
+
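+# Root of exo's on-disk cache; the default ~/.cache/exo can be overridden via the EXO_HOME environment variable.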
+def exo_home() -> Path:
+  return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
+
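+# Create the downloads directory under the cache root if it does not already exist.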
+async def ensure_downloads_dir() -> Path:
+  downloads_dir = exo_home()/"downloads"
+  await aios.makedirs(downloads_dir, exist_ok=True)
+  return downloads_dir
+
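+# Recursively walk a repo tree via the Hugging Face API, returning the path and size of every file.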
+async def fetch_file_list(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str = "") -> list[dict]:
+  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
+  url = f"{api_url}/{path}" if path else api_url
+
+  headers = await get_auth_headers()
+  async with session.get(url, headers=headers) as response:
+    if response.status == 200:
+      data = await response.json()
+      files = []
+      for item in data:
+        if item["type"] == "file":
+          files.append({"path": item["path"], "size": item["size"]})
+        elif item["type"] == "directory":
+          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
+          files.extend(subfiles)
+      return files
+    else:
+      raise Exception(f"Failed to fetch file list: {response.status}")
+
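+# Stream one file into a temporary file in target_dir, then rename it into place so a partial
+# download never appears under its final name. Files already on disk are skipped.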
+async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Optional[Callable[[int, int], None]] = None) -> Path:
+  if (target_dir/path).exists(): return target_dir/path
+  await aios.makedirs((target_dir/path).parent, exist_ok=True)
+  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
+  url = urljoin(base_url, path)
+  headers = await get_auth_headers()
+  async with session.get(url, headers=headers) as r:
+    assert r.status == 200, f"Failed to download {path}: HTTP {r.status}"
+    length = int(r.headers.get('content-length', 0))
+    n_read = 0
+    async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
+      while chunk := await r.content.read(16384):
+        n_read += await temp_file.write(chunk)
+        if on_progress: on_progress(n_read, length)
+    await aios.rename(temp_file.name, target_dir/path)
+    return target_dir/path
+
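+# Fold per-file progress events into a single repo-level event with aggregate speed and ETA.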
+def calculate_repo_progress(repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
+  all_total_bytes = sum([p.total for p in file_progress.values()])
+  all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()])
+  elapsed_time = time.time() - all_start_time
+  all_speed = all_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
+  all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
+  status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress"
+  return RepoProgressEvent(repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes, all_total_bytes, all_speed, all_eta, file_progress, status)
+
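+# Fetch a repo's model.safetensors.index.json and return its tensor-name -> weight-file mapping.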
+async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
+  target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
+  async with aiohttp.ClientSession() as session:
+    index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
+    async with aiofiles.open(index_file, 'r') as f:
+      index_data = json.loads(await f.read())
+    return index_data.get("weight_map")
+
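+# Translate a shard into allow-patterns covering only the weight files that shard actually needs.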
+async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> list[str]:
+  weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
+  return get_allow_patterns(weight_map, shard)
+
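+# Download every file a shard needs (or, with skip_download=True, just report on-disk status),
+# emitting repo-level progress events along the way.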
+async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
+  if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
+  repo_id = get_repo(shard.model_id, inference_engine_classname)
+  if repo_id is None:
+    raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}")
+  revision = "main"
+  target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
+  if not skip_download: await aios.makedirs(target_dir, exist_ok=True)
+
+  allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname) if not skip_download else None
+  if DEBUG >= 3: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
+
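+  # List the repo, narrow it to the files this shard needs, and seed per-file progress from what is already on disk.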
+  all_start_time = time.time()
+  async with aiohttp.ClientSession() as session:
+    file_list = await fetch_file_list(session, repo_id, revision)
+    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
+    file_progress: Dict[str, RepoFileProgressEvent] = {}
+    def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
+      start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
+      elapsed = time.time() - start_time
+      speed = curr_bytes / elapsed if elapsed > 0 else 0
+      eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) if speed > 0 else timedelta(seconds=0)
+      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, curr_bytes, total_bytes, speed, eta, "in_progress", start_time)
+      on_progress.trigger_all(shard, calculate_repo_progress(repo_id, revision, file_progress, all_start_time))
+      if DEBUG >= 2: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
+    for file in filtered_file_list:
+      downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
+      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time())
+
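+    # Bound concurrency so at most max_parallel_downloads transfers are in flight at once.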
+    semaphore = asyncio.Semaphore(max_parallel_downloads)
+    async def download_with_semaphore(file):
+      async with semaphore:
+        await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
+    if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
+    final_repo_progress = calculate_repo_progress(repo_id, revision, file_progress, all_start_time)
+    on_progress.trigger_all(shard, final_repo_progress)
+    return target_dir, final_repo_progress
+
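+# Compose the downloader layers: singleton de-duplication over a completed-download cache over the real downloader.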
+def new_shard_downloader() -> ShardDownloader:
+  return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))
+
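+# A minimal usage sketch (the model id and engine classname below are illustrative, not prescriptive):
+#   downloader = new_shard_downloader()
+#   shard = build_base_shard("llama-3.2-1b", "MLXDynamicShardInferenceEngine")
+#   model_dir = asyncio.run(downloader.ensure_shard(shard, "MLXDynamicShardInferenceEngine"))
+
+# De-duplicates concurrent requests: callers asking for a shard that is already downloading share one task.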
+class SingletonShardDownloader(ShardDownloader):
+  def __init__(self, shard_downloader: ShardDownloader):
+    self.shard_downloader = shard_downloader
+    self.active_downloads: Dict[Shard, asyncio.Task] = {}
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self.shard_downloader.on_progress
+
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    if shard not in self.active_downloads: self.active_downloads[shard] = asyncio.create_task(self.shard_downloader.ensure_shard(shard, inference_engine_name))
+    try: return await self.active_downloads[shard]
+    finally:
+      if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]
+
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+
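+# Memoizes completed downloads so repeat ensure_shard calls for the same shard skip the network entirely.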
+class CachedShardDownloader(ShardDownloader):
+  def __init__(self, shard_downloader: ShardDownloader):
+    self.shard_downloader = shard_downloader
+    self.cache: Dict[tuple[str, Shard], Path] = {}
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self.shard_downloader.on_progress
+
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    if (inference_engine_name, shard) in self.cache:
+      if DEBUG >= 2: print(f"ensure_shard cache hit {shard=} for {inference_engine_name}")
+      return self.cache[(inference_engine_name, shard)]
+    if DEBUG >= 2: print(f"ensure_shard cache miss {shard=} for {inference_engine_name}")
+    target_dir = await self.shard_downloader.ensure_shard(shard, inference_engine_name)
+    self.cache[(inference_engine_name, shard)] = target_dir
+    return target_dir
+
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+
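+# The innermost layer: performs the actual downloads via download_shard and fans out progress callbacks.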
+class NewShardDownloader(ShardDownloader):
+  def __init__(self):
+    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self._on_progress
+
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
+    return target_dir
+
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+    if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
+    downloads = await asyncio.gather(*[download_shard(build_base_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
+    if DEBUG >= 6: print("Downloaded shards:", downloads)
+    if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
+    return [d for d in downloads if not isinstance(d, Exception)]
+