@@ -19,6 +19,7 @@ import json
import traceback
import shutil
import tempfile
+from tenacity import retry, stop_after_attempt, wait_exponential

def exo_home() -> Path:
return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -26,6 +27,10 @@ def exo_home() -> Path:
def exo_tmp() -> Path:
return Path(tempfile.gettempdir())/"exo"

+async def ensure_exo_home() -> Path:
+ await aios.makedirs(exo_home(), exist_ok=True)
+ return exo_home()
+
async def ensure_exo_tmp() -> Path:
await aios.makedirs(exo_tmp(), exist_ok=True)
return exo_tmp()
@@ -64,39 +69,42 @@ async def seed_models(seed_dir: Union[str, Path]):
print(f"Error seeding model {path} to {dest_path}")
traceback.print_exc()

-async def fetch_file_list(session, repo_id, revision, path=""):
+async def fetch_file_list(repo_id, revision, path=""):
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url

headers = await get_auth_headers()
- async with session.get(url, headers=headers) as response:
- if response.status == 200:
- data = await response.json()
- files = []
- for item in data:
- if item["type"] == "file":
- files.append({"path": item["path"], "size": item["size"]})
- elif item["type"] == "directory":
- subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
- files.extend(subfiles)
- return files
- else:
- raise Exception(f"Failed to fetch file list: {response.status}")
+ async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30, connect=10, sock_read=1800, sock_connect=60)) as session:
+ async with session.get(url, headers=headers) as response:
+ if response.status == 200:
+ data = await response.json()
+ files = []
+ for item in data:
+ if item["type"] == "file":
+ files.append({"path": item["path"], "size": item["size"]})
+ elif item["type"] == "directory":
+ subfiles = await fetch_file_list(repo_id, revision, item["path"])
+ files.extend(subfiles)
+ return files
+ else:
+ raise Exception(f"Failed to fetch file list: {response.status}")

-async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
+@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
+async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
if (target_dir/path).exists(): return target_dir/path
await aios.makedirs((target_dir/path).parent, exist_ok=True)
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
url = urljoin(base_url, path)
headers = await get_auth_headers()
- async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
- assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
- length = int(r.headers.get('content-length', 0))
- n_read = 0
- async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
- while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
- await aios.rename(temp_file.name, target_dir/path)
- return target_dir/path
+ async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
+ async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
+ assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
+ length = int(r.headers.get('content-length', 0))
+ n_read = 0
+ async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
+ while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
+ await aios.rename(temp_file.name, target_dir/path)
+ return target_dir/path

def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
all_total_bytes = sum([p.total for p in file_progress.values()])
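
The @retry decorator added to download_file comes from tenacity (imported at the top of this diff): stop_after_attempt(5) gives up after five calls, wait_exponential(multiplier=0.5) roughly doubles the pause between attempts starting at around half a second, and by default the exhausted retries end in a tenacity.RetryError. A minimal sketch of the same pattern, outside the patch; flaky_fetch is a hypothetical stand-in:

import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
async def flaky_fetch() -> bytes:
  # Hypothetical stand-in for a transiently failing request; tenacity re-invokes
  # the coroutine with exponential backoff between attempts.
  raise ConnectionError("transient network error")

# asyncio.run(flaky_fetch())  # raises tenacity.RetryError after the fifth failed attempt
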
@@ -110,10 +118,9 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog

async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--")
- async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=10, sock_read=1800, sock_connect=10)) as session:
- index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
- async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
- return index_data.get("weight_map")
+ index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir)
+ async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
+ return index_data.get("weight_map")

async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> List[str]:
try:
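
With the shared session removed, get_weight_map above now takes only a repo id and an optional revision. A minimal usage sketch, assuming this module's other imports are available, the repo ships a standard model.safetensors.index.json, and network access works; the repo id is a placeholder:

import asyncio

async def show_weight_map():
  # weight_map maps tensor names to the .safetensors file containing them,
  # e.g. {"model.embed_tokens.weight": "model-00001-of-00002.safetensors", ...}
  weight_map = await get_weight_map("some-org/some-model", revision="main")
  print(sorted(set(weight_map.values())))  # distinct shard files referenced by the index

# asyncio.run(show_weight_map())
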
@@ -138,30 +145,32 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")

all_start_time = time.time()
- async with aiohttp.ClientSession() as session:
- file_list = await fetch_file_list(session, repo_id, revision)
- filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
- file_progress: Dict[str, RepoFileProgressEvent] = {}
- def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
- start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
- downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
- speed = downloaded_this_session / (time.time() - start_time)
- eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
- file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time)
- on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
- if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
- for file in filtered_file_list:
- downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
- file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
+ file_list = await fetch_file_list(repo_id, revision)
+ filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
+ file_progress: Dict[str, RepoFileProgressEvent] = {}
+ def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
+ start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
+ downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
+ speed = downloaded_this_session / (time.time() - start_time)
+ eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
+ file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time)
+ on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
+ if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
+ for file in filtered_file_list:
+ downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
+ file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())

semaphore = asyncio.Semaphore(max_parallel_downloads)
async def download_with_semaphore(file):
async with semaphore:
- await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
+ await download_file(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
on_progress.trigger_all(shard, final_repo_progress)
- return target_dir, final_repo_progress
+ if gguf := next((f for f in filtered_file_list if f["path"].endswith(".gguf")), None):
+ return target_dir/gguf["path"], final_repo_progress
+ else:
+ return target_dir, final_repo_progress

def new_shard_downloader() -> ShardDownloader:
return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))
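
One usage note on the changed return value of download_shard: when the filtered file list contains a .gguf file, the function now returns the path of that file instead of the repo directory, so callers need to handle both shapes. A hedged sketch (download_shard's full signature is abbreviated in the hunk header above, and both loader helpers are hypothetical):

async def load_downloaded_shard(shard, inference_engine_classname, on_progress):
  path, final_progress = await download_shard(shard, inference_engine_classname, on_progress)
  if path.is_file() and path.suffix == ".gguf":
    return load_gguf_model(path)  # hypothetical loader for single-file GGUF models
  return load_safetensors_dir(path)  # hypothetical loader for a directory of safetensors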