|
@@ -1,36 +1,17 @@
|
|
import asyncio
|
|
import asyncio
|
|
import aiohttp
|
|
import aiohttp
|
|
import os
|
|
import os
|
|
-import argparse
|
|
|
|
from urllib.parse import urljoin
|
|
from urllib.parse import urljoin
|
|
-from typing import Callable, Optional, Coroutine, Any
|
|
|
|
|
|
+from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
|
|
from datetime import datetime, timedelta
|
|
from datetime import datetime, timedelta
|
|
from fnmatch import fnmatch
|
|
from fnmatch import fnmatch
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
-from typing import Generator, Iterable, List, TypeVar, Union
|
|
|
|
|
|
+from typing import Generator, Iterable, TypeVar, TypedDict
|
|
|
|
+from dataclasses import dataclass
|
|
|
|
+from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
|
|
|
+from exo.helpers import DEBUG
|
|
|
|
|
|
T = TypeVar("T")
|
|
T = TypeVar("T")
|
|
-
|
|
|
|
-DEFAULT_ALLOW_PATTERNS = [
|
|
|
|
- "*.json",
|
|
|
|
- "*.py",
|
|
|
|
- "tokenizer.model",
|
|
|
|
- "*.tiktoken",
|
|
|
|
- "*.txt",
|
|
|
|
- "*.safetensors",
|
|
|
|
-]
|
|
|
|
-# Always ignore `.git` and `.cache/huggingface` folders in commits
|
|
|
|
-DEFAULT_IGNORE_PATTERNS = [
|
|
|
|
- ".git",
|
|
|
|
- ".git/*",
|
|
|
|
- "*/.git",
|
|
|
|
- "**/.git/**",
|
|
|
|
- ".cache/huggingface",
|
|
|
|
- ".cache/huggingface/*",
|
|
|
|
- "*/.cache/huggingface",
|
|
|
|
- "**/.cache/huggingface/**",
|
|
|
|
-]
|
|
|
|
-
|
|
|
|
def filter_repo_objects(
|
|
def filter_repo_objects(
|
|
items: Iterable[T],
|
|
items: Iterable[T],
|
|
*,
|
|
*,
|
|
@@ -117,7 +98,35 @@ async def fetch_file_list(session, repo_id, revision, path=""):
|
|
else:
|
|
else:
|
|
raise Exception(f"Failed to fetch file list: {response.status}")
|
|
raise Exception(f"Failed to fetch file list: {response.status}")
|
|
|
|
|
|
-async def download_file(session, repo_id, revision, file_path, save_directory, progress_callback: Optional[Callable[[str, int, int, float, timedelta], Coroutine[Any, Any, None]]] = None):
|
|
|
|
|
|
+
|
|
|
|
+@dataclass
|
|
|
|
+class HFRepoFileProgressEvent:
|
|
|
|
+ file_path: str
|
|
|
|
+ downloaded: int
|
|
|
|
+ total: int
|
|
|
|
+ speed: float
|
|
|
|
+ eta: timedelta
|
|
|
|
+ status: Literal["not_started", "in_progress", "complete"]
|
|
|
|
+
|
|
|
|
+@dataclass
|
|
|
|
+class HFRepoProgressEvent:
|
|
|
|
+ completed_files: int
|
|
|
|
+ total_files: int
|
|
|
|
+ downloaded_bytes: int
|
|
|
|
+ total_bytes: int
|
|
|
|
+ overall_eta: timedelta
|
|
|
|
+ file_progress: Dict[str, HFRepoFileProgressEvent]
|
|
|
|
+
|
|
|
|
+HFRepoFileProgressCallback = Callable[[HFRepoFileProgressEvent], Coroutine[Any, Any, None]]
|
|
|
|
+HFRepoProgressCallback = Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]
|
|
|
|
+
|
|
|
|
+@retry(
|
|
|
|
+ stop=stop_after_attempt(5),
|
|
|
|
+ wait=wait_exponential(multiplier=1, min=4, max=60),
|
|
|
|
+ retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
|
|
|
|
+ reraise=True
|
|
|
|
+)
|
|
|
|
+async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[HFRepoFileProgressCallback] = None, use_range_request: bool = True):
|
|
base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
|
|
base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
|
|
url = urljoin(base_url, file_path)
|
|
url = urljoin(base_url, file_path)
|
|
local_path = os.path.join(save_directory, file_path)
|
|
local_path = os.path.join(save_directory, file_path)
|
|
@@ -125,64 +134,72 @@ async def download_file(session, repo_id, revision, file_path, save_directory, p
|
|
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
|
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
|
|
|
|
|
# Check if file already exists and get its size
|
|
# Check if file already exists and get its size
|
|
- if os.path.exists(local_path):
|
|
|
|
- local_file_size = os.path.getsize(local_path)
|
|
|
|
- else:
|
|
|
|
- local_file_size = 0
|
|
|
|
|
|
+ local_file_size = os.path.getsize(local_path) if os.path.exists(local_path) else 0
|
|
|
|
+
|
|
|
|
+ headers = get_auth_headers()
|
|
|
|
+ if use_range_request:
|
|
|
|
+ headers["Range"] = f"bytes={local_file_size}-"
|
|
|
|
|
|
- headers = {"Range": f"bytes={local_file_size}-"}
|
|
|
|
- headers.update(get_auth_headers())
|
|
|
|
async with session.get(url, headers=headers) as response:
|
|
async with session.get(url, headers=headers) as response:
|
|
|
|
+ total_size = int(response.headers.get('Content-Length', 0))
|
|
|
|
+ downloaded_size = local_file_size
|
|
|
|
+ mode = 'ab' if use_range_request else 'wb'
|
|
|
|
+ if downloaded_size == total_size:
|
|
|
|
+ if DEBUG >= 2: print(f"File already downloaded: {file_path}")
|
|
|
|
+ if progress_callback:
|
|
|
|
+ await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
|
|
|
|
+ return
|
|
|
|
+
|
|
if response.status == 200:
|
|
if response.status == 200:
|
|
- # File doesn't support range requests, start from beginning
|
|
|
|
|
|
+ # File doesn't support range requests or we're not using them, start from beginning
|
|
mode = 'wb'
|
|
mode = 'wb'
|
|
- total_size = int(response.headers.get('Content-Length', 0))
|
|
|
|
downloaded_size = 0
|
|
downloaded_size = 0
|
|
elif response.status == 206:
|
|
elif response.status == 206:
|
|
# Partial content, resume download
|
|
# Partial content, resume download
|
|
- mode = 'ab'
|
|
|
|
- content_range = response.headers.get('Content-Range')
|
|
|
|
- total_size = int(content_range.split('/')[-1])
|
|
|
|
- downloaded_size = local_file_size
|
|
|
|
|
|
+ content_range = response.headers.get('Content-Range', '')
|
|
|
|
+ try:
|
|
|
|
+ total_size = int(content_range.split('/')[-1])
|
|
|
|
+ except ValueError:
|
|
|
|
+ if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
|
|
|
|
+ return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
|
|
elif response.status == 416:
|
|
elif response.status == 416:
|
|
# Range not satisfiable, get the actual file size
|
|
# Range not satisfiable, get the actual file size
|
|
- if response.headers.get('Content-Type', '').startswith('text/html'):
|
|
|
|
- content = await response.text()
|
|
|
|
- print(f"Response content (HTML):\n{content}")
|
|
|
|
- else:
|
|
|
|
- print(response)
|
|
|
|
- print("Return header: ", response.headers)
|
|
|
|
- print("Return header: ", response.headers.get('Content-Range').split('/')[-1])
|
|
|
|
- total_size = int(response.headers.get('Content-Range', '').split('/')[-1])
|
|
|
|
- if local_file_size == total_size:
|
|
|
|
- print(f"File already fully downloaded: {file_path}")
|
|
|
|
- return
|
|
|
|
- else:
|
|
|
|
- # Start the download from the beginning
|
|
|
|
- mode = 'wb'
|
|
|
|
- downloaded_size = 0
|
|
|
|
|
|
+ content_range = response.headers.get('Content-Range', '')
|
|
|
|
+ try:
|
|
|
|
+ total_size = int(content_range.split('/')[-1])
|
|
|
|
+ if downloaded_size == total_size:
|
|
|
|
+ if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
|
|
|
|
+ if progress_callback:
|
|
|
|
+ await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
|
|
|
|
+ return
|
|
|
|
+ except ValueError:
|
|
|
|
+ if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
|
|
|
|
+ return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
|
|
else:
|
|
else:
|
|
- print(f"Failed to download {file_path}: {response.status}")
|
|
|
|
- return
|
|
|
|
|
|
+ raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
|
|
|
|
|
|
if downloaded_size == total_size:
|
|
if downloaded_size == total_size:
|
|
print(f"File already downloaded: {file_path}")
|
|
print(f"File already downloaded: {file_path}")
|
|
|
|
+ if progress_callback:
|
|
|
|
+ await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
|
|
return
|
|
return
|
|
|
|
|
|
|
|
+ DOWNLOAD_CHUNK_SIZE = 32768
|
|
start_time = datetime.now()
|
|
start_time = datetime.now()
|
|
- new_downloaded_size = 0
|
|
|
|
with open(local_path, mode) as f:
|
|
with open(local_path, mode) as f:
|
|
- async for chunk in response.content.iter_chunked(8192):
|
|
|
|
|
|
+ async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
|
|
f.write(chunk)
|
|
f.write(chunk)
|
|
- new_downloaded_size += len(chunk)
|
|
|
|
- if progress_callback:
|
|
|
|
|
|
+ downloaded_size += len(chunk)
|
|
|
|
+ if progress_callback and total_size:
|
|
elapsed_time = (datetime.now() - start_time).total_seconds()
|
|
elapsed_time = (datetime.now() - start_time).total_seconds()
|
|
- speed = new_downloaded_size / elapsed_time if elapsed_time > 0 else 0
|
|
|
|
- eta = timedelta(seconds=(total_size - downloaded_size - new_downloaded_size) / speed) if speed > 0 else timedelta(0)
|
|
|
|
- await progress_callback(file_path, new_downloaded_size, total_size - downloaded_size, speed, eta)
|
|
|
|
- print(f"Downloaded: {file_path}")
|
|
|
|
-
|
|
|
|
-async def download_all_files(repo_id, revision="main", progress_callback: Optional[Callable[[int, int, int, int, timedelta, dict], Coroutine[Any, Any, None]]] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
|
|
|
|
|
|
+ speed = downloaded_size / elapsed_time if elapsed_time > 0 else 0
|
|
|
|
+ remaining_size = total_size - downloaded_size
|
|
|
|
+ eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
|
|
|
|
+ status = "in_progress" if downloaded_size < total_size else "complete"
|
|
|
|
+ await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, speed, eta, status))
|
|
|
|
+ if DEBUG >= 2: print(f"Downloaded: {file_path}")
|
|
|
|
+
|
|
|
|
+async def download_all_files(repo_id: str, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
|
|
repo_root = get_repo_root(repo_id)
|
|
repo_root = get_repo_root(repo_id)
|
|
refs_dir = repo_root / "refs"
|
|
refs_dir = repo_root / "refs"
|
|
snapshots_dir = repo_root / "snapshots"
|
|
snapshots_dir = repo_root / "snapshots"
|
|
@@ -197,7 +214,7 @@ async def download_all_files(repo_id, revision="main", progress_callback: Option
|
|
headers = get_auth_headers()
|
|
headers = get_auth_headers()
|
|
async with session.get(api_url, headers=headers) as response:
|
|
async with session.get(api_url, headers=headers) as response:
|
|
if response.status != 200:
|
|
if response.status != 200:
|
|
- raise Exception(f"Failed to fetch revision info: {response.status}")
|
|
|
|
|
|
+ raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
|
|
revision_info = await response.json()
|
|
revision_info = await response.json()
|
|
commit_hash = revision_info['sha']
|
|
commit_hash = revision_info['sha']
|
|
|
|
|
|
@@ -215,68 +232,32 @@ async def download_all_files(repo_id, revision="main", progress_callback: Option
|
|
completed_files = 0
|
|
completed_files = 0
|
|
total_bytes = sum(file["size"] for file in filtered_file_list)
|
|
total_bytes = sum(file["size"] for file in filtered_file_list)
|
|
downloaded_bytes = 0
|
|
downloaded_bytes = 0
|
|
- new_downloaded_bytes = 0
|
|
|
|
- file_progress = {file["path"]: {"status": "not_started", "downloaded": 0, "total": file["size"]} for file in filtered_file_list}
|
|
|
|
|
|
+ file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
|
|
start_time = datetime.now()
|
|
start_time = datetime.now()
|
|
|
|
|
|
async def download_with_progress(file_info):
|
|
async def download_with_progress(file_info):
|
|
- nonlocal completed_files, downloaded_bytes, new_downloaded_bytes, file_progress
|
|
|
|
-
|
|
|
|
- async def file_progress_callback(path, file_downloaded, file_total, speed, file_eta):
|
|
|
|
- nonlocal downloaded_bytes, new_downloaded_bytes, file_progress
|
|
|
|
- new_downloaded_bytes += file_downloaded - file_progress[path]['downloaded']
|
|
|
|
- downloaded_bytes += file_downloaded - file_progress[path]['downloaded']
|
|
|
|
- file_progress[path].update({
|
|
|
|
- 'status': 'in_progress',
|
|
|
|
- 'downloaded': file_downloaded,
|
|
|
|
- 'total': file_total,
|
|
|
|
- 'speed': speed,
|
|
|
|
- 'eta': file_eta
|
|
|
|
- })
|
|
|
|
|
|
+ nonlocal completed_files, downloaded_bytes, file_progress
|
|
|
|
+
|
|
|
|
+ async def file_progress_callback(event: HFRepoFileProgressEvent):
|
|
|
|
+ nonlocal downloaded_bytes, file_progress
|
|
|
|
+ downloaded_bytes += event.downloaded - file_progress[event.file_path].downloaded
|
|
|
|
+ file_progress[event.file_path] = event
|
|
if progress_callback:
|
|
if progress_callback:
|
|
elapsed_time = (datetime.now() - start_time).total_seconds()
|
|
elapsed_time = (datetime.now() - start_time).total_seconds()
|
|
- overall_speed = new_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
|
|
|
|
|
|
+ overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
|
|
overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
|
|
overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
|
|
- await progress_callback(completed_files, total_files, new_downloaded_bytes, total_bytes, overall_eta, file_progress)
|
|
|
|
|
|
+ await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))
|
|
|
|
|
|
await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
|
|
await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
|
|
completed_files += 1
|
|
completed_files += 1
|
|
- file_progress[file_info["path"]]['status'] = 'complete'
|
|
|
|
|
|
+ file_progress[file_info["path"]] = HFRepoFileProgressEvent(file_info["path"], file_info["size"], file_info["size"], 0, timedelta(0), "complete")
|
|
if progress_callback:
|
|
if progress_callback:
|
|
elapsed_time = (datetime.now() - start_time).total_seconds()
|
|
elapsed_time = (datetime.now() - start_time).total_seconds()
|
|
- overall_speed = new_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
|
|
|
|
|
|
+ overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
|
|
overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
|
|
overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
|
|
- await progress_callback(completed_files, total_files, new_downloaded_bytes, total_bytes, overall_eta, file_progress)
|
|
|
|
|
|
+ await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))
|
|
|
|
|
|
tasks = [download_with_progress(file_info) for file_info in filtered_file_list]
|
|
tasks = [download_with_progress(file_info) for file_info in filtered_file_list]
|
|
await asyncio.gather(*tasks)
|
|
await asyncio.gather(*tasks)
|
|
|
|
|
|
-async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
|
|
|
|
- async def progress_callback(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress):
|
|
|
|
- print(f"Overall Progress: {completed_files}/{total_files} files, {downloaded_bytes}/{total_bytes} bytes")
|
|
|
|
- print(f"Estimated time remaining: {overall_eta}")
|
|
|
|
- print("File Progress:")
|
|
|
|
- for file_path, progress in file_progress.items():
|
|
|
|
- status_icon = {
|
|
|
|
- 'not_started': '⚪',
|
|
|
|
- 'in_progress': '🔵',
|
|
|
|
- 'complete': '✅'
|
|
|
|
- }[progress['status']]
|
|
|
|
- eta_str = str(progress.get('eta', 'N/A'))
|
|
|
|
- print(f"{status_icon} {file_path}: {progress.get('downloaded', 0)}/{progress['total']} bytes, "
|
|
|
|
- f"Speed: {progress.get('speed', 0):.2f} B/s, ETA: {eta_str}")
|
|
|
|
- print("\n")
|
|
|
|
-
|
|
|
|
- await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-if __name__ == "__main__":
|
|
|
|
- parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
|
|
|
|
- parser.add_argument("--repo-id", help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
|
|
|
|
- parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
|
|
|
|
- parser.add_argument("--allow-patterns", nargs="*", default=DEFAULT_ALLOW_PATTERNS, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
|
|
|
|
- parser.add_argument("--ignore-patterns", nargs="*", default=DEFAULT_IGNORE_PATTERNS, help="Patterns of files to ignore (e.g., '.*')")
|
|
|
|
-
|
|
|
|
- args = parser.parse_args()
|
|
|
|
-
|
|
|
|
- asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
|
|
|
|
|
|
+ return snapshot_dir
|