@@ -12,6 +12,8 @@ from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_excep
 from exo.helpers import DEBUG
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
+import aiofiles
+from aiofiles import os as aios
 
 T = TypeVar("T")
 def filter_repo_objects(
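
The two new imports drive the whole change: aiofiles wraps blocking file I/O in a thread-pool executor so it can be awaited, and aiofiles.os (aliased aios here) does the same for os calls. A minimal sketch of the pattern, independent of this repo:

    import asyncio
    import aiofiles
    from aiofiles import os as aios

    async def read_token_like_file(path):
      # Each call dispatches the blocking operation to a worker thread,
      # leaving the event loop free to run other coroutines meanwhile.
      if await aios.path.exists(path):
        async with aiofiles.open(path, 'r') as f:
          return (await f.read()).strip()
      return None

    print(asyncio.run(read_token_like_file("/etc/hostname")))  # any readable file works
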
@@ -56,16 +58,17 @@ def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
   return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
 
-def get_hf_token():
+async def get_hf_token():
   """Retrieve the Hugging Face token from the user's HF_HOME directory."""
   token_path = get_hf_home() / "token"
-  if token_path.exists():
-    return token_path.read_text().strip()
+  if await aios.path.exists(token_path):
+    async with aiofiles.open(token_path, 'r') as f:
+      return (await f.read()).strip()
   return None
 
-def get_auth_headers():
+async def get_auth_headers():
   """Get authentication headers if a token is available."""
-  token = get_hf_token()
+  token = await get_hf_token()
   if token:
     return {"Authorization": f"Bearer {token}"}
   return {}
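
Since get_hf_token and get_auth_headers are now coroutines, every caller has to await them, which is exactly what the remaining hunks do. A quick sanity check could look like:

    import asyncio

    async def main():
      headers = await get_auth_headers()
      print(headers)  # {} without a token file, {"Authorization": "Bearer ..."} with one

    asyncio.run(main())
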
@@ -79,7 +82,7 @@ async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
 
-  headers = get_auth_headers()
+  headers = await get_auth_headers()
   async with session.get(url, headers=headers) as response:
     if response.status == 200:
       data = await response.json()
@@ -106,12 +109,12 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
   url = urljoin(base_url, file_path)
   local_path = os.path.join(save_directory, file_path)
 
-  os.makedirs(os.path.dirname(local_path), exist_ok=True)
+  await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
 
   # Check if file already exists and get its size
-  local_file_size = os.path.getsize(local_path) if os.path.exists(local_path) else 0
+  local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
 
-  headers = get_auth_headers()
+  headers = await get_auth_headers()
   if use_range_request:
     headers["Range"] = f"bytes={local_file_size}-"
 
@@ -162,9 +165,9 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
 
   DOWNLOAD_CHUNK_SIZE = 32768
   start_time = datetime.now()
-  with open(local_path, mode) as f:
+  async with aiofiles.open(local_path, mode) as f:
     async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
-      f.write(chunk)
+      await f.write(chunk)
       downloaded_size += len(chunk)
       downloaded_this_session += len(chunk)
       if progress_callback and total_size:
@@ -177,34 +180,60 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
         await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
   if DEBUG >= 2: print(f"Downloaded: {file_path}")
 
-async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None) -> Path:
+async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, max_parallel_downloads: int = 4) -> Path:
   repo_root = get_repo_root(repo_id)
   refs_dir = repo_root / "refs"
   snapshots_dir = repo_root / "snapshots"
+  cachedreqs_dir = repo_root / "cachedreqs"
 
   # Ensure directories exist
-  refs_dir.mkdir(parents=True, exist_ok=True)
-  snapshots_dir.mkdir(parents=True, exist_ok=True)
+  await aios.makedirs(refs_dir, exist_ok=True)
+  await aios.makedirs(snapshots_dir, exist_ok=True)
+  await aios.makedirs(cachedreqs_dir, exist_ok=True)
+
+  # Check if we have a cached commit hash
+  refs_file = refs_dir / revision
+  if await aios.path.exists(refs_file):
+    async with aiofiles.open(refs_file, 'r') as f:
+      commit_hash = (await f.read()).strip()
+    if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
+  else:
+    async with aiohttp.ClientSession() as session:
+      # Fetch the commit hash for the given revision
+      api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
+      headers = await get_auth_headers()
+      async with session.get(api_url, headers=headers) as response:
+        if response.status != 200:
+          raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
+        revision_info = await response.json()
+        commit_hash = revision_info['sha']
+
+    # Cache the commit hash
+    async with aiofiles.open(refs_file, 'w') as f:
+      await f.write(commit_hash)
+
+  # Set up the snapshot directory
+  snapshot_dir = snapshots_dir / commit_hash
+  await aios.makedirs(snapshot_dir, exist_ok=True)
+
+  # Set up the cached file list directory
+  cached_file_list_dir = cachedreqs_dir / commit_hash
+  await aios.makedirs(cached_file_list_dir, exist_ok=True)
+  cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
 
   async with aiohttp.ClientSession() as session:
-    # Fetch the commit hash for the given revision
-    api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
-    headers = get_auth_headers()
-    async with session.get(api_url, headers=headers) as response:
-      if response.status != 200:
-        raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
-      revision_info = await response.json()
-      commit_hash = revision_info['sha']
-
-    # Write the commit hash to the refs file
-    refs_file = refs_dir / revision
-    refs_file.write_text(commit_hash)
-
-    # Set up the snapshot directory
-    snapshot_dir = snapshots_dir / commit_hash
-    snapshot_dir.mkdir(exist_ok=True)
-
-    file_list = await fetch_file_list(session, repo_id, revision)
+    # Check if we have a cached file list
+    if await aios.path.exists(cached_file_list_path):
+      async with aiofiles.open(cached_file_list_path, 'r') as f:
+        file_list = json.loads(await f.read())
+      if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
+    else:
+      file_list = await fetch_file_list(session, repo_id, revision)
+      # Cache the file list
+      async with aiofiles.open(cached_file_list_path, 'w') as f:
+        await f.write(json.dumps(file_list))
+      if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
+
     filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
     total_files = len(filtered_file_list)
     total_bytes = sum(file["size"] for file in filtered_file_list)
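
Taken together, the caching pieces give each repo root (wherever get_repo_root resolves) roughly this layout, with cachedreqs/ being the new addition:

    <repo_root>/
      refs/<revision>                                 # cached commit hash, e.g. refs/main
      snapshots/<commit_hash>/...                     # the downloaded files themselves
      cachedreqs/<commit_hash>/fetch_file_list.json   # cached file-list API response
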
@@ -212,6 +241,21 @@ async def download_repo_files(repo_id: str, revision: str = "main", progress_cal
     start_time = datetime.now()
 
     async def download_with_progress(file_info, progress_state):
+      local_path = snapshot_dir / file_info["path"]
+      if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
+        if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
+        progress_state['completed_files'] += 1
+        progress_state['downloaded_bytes'] += file_info["size"]
+        file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
+        if progress_callback:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+          overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+          status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+          await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+        return
+
       async def file_progress_callback(event: RepoFileProgressEvent):
         progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
         progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
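
The speed and ETA fields are plain rate extrapolation: bytes downloaded this session over elapsed wall time, then remaining bytes over that rate. A worked example:

    from datetime import timedelta

    downloaded_this_session = 250 * 1024 * 1024  # 250 MiB so far
    elapsed_time = 10.0                          # seconds since start_time
    remaining_bytes = 750 * 1024 * 1024          # 750 MiB still to fetch

    overall_speed = int(downloaded_this_session / elapsed_time)       # 25 MiB/s
    overall_eta = timedelta(seconds=remaining_bytes / overall_speed)  # ~30 s
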
@@ -236,7 +280,12 @@ async def download_repo_files(repo_id: str, revision: str = "main", progress_cal
         await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
 
     progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
-    tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
+
+    semaphore = asyncio.Semaphore(max_parallel_downloads)
+    async def download_with_semaphore(file_info):
+      async with semaphore:
+        await download_with_progress(file_info, progress_state)
+    tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
     await asyncio.gather(*tasks)
 
   return snapshot_dir
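
A minimal end-to-end sketch of the new entry point; the repo id and patterns are illustrative, and the progress-event field names are assumed to match the constructor arguments above:

    import asyncio

    async def main():
      async def on_progress(event):  # RepoProgressEvent; field names assumed from the constructor
        print(f"{event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")

      snapshot_dir = await download_repo_files(
        "some-org/some-model",                      # hypothetical repo id
        allow_patterns=["*.json", "*.safetensors"],
        progress_callback=on_progress,
        max_parallel_downloads=4,                   # the new knob introduced by this diff
      )
      print(f"Snapshot ready at {snapshot_dir}")

    asyncio.run(main())
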
@@ -263,12 +312,14 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[
   # Check if the file exists
   repo_root = get_repo_root(repo_id)
   snapshot_dir = repo_root / "snapshots"
-  index_file = next(snapshot_dir.glob("*/model.safetensors.index.json"), None)
-
-  if index_file and index_file.exists():
-    with open(index_file, 'r') as f:
-      index_data = json.load(f)
-      return index_data.get("weight_map")
+  # The index file lives inside a commit-hash subdirectory of snapshots/
+  if await aios.path.exists(snapshot_dir):
+    for commit_dir in await aios.listdir(snapshot_dir):
+      index_file_path = snapshot_dir / commit_dir / "model.safetensors.index.json"
+      if await aios.path.exists(index_file_path):
+        async with aiofiles.open(index_file_path, 'r') as f:
+          index_data = json.loads(await f.read())
+        return index_data.get("weight_map")
 
   return None
 
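
For context, a safetensors index's weight_map maps tensor names to shard filenames, so a caller can fetch only the shards a given layer range needs. A hypothetical illustration (run inside a coroutine):

    # weight_map looks like:
    #   {"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors", ...}
    weight_map = await get_weight_map("some-org/some-model")  # hypothetical repo id
    if weight_map:
      layer0_shards = {fname for name, fname in weight_map.items() if ".layers.0." in name}
      print(layer0_shards)
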