# hf_async.py — asynchronously download model files from the Hugging Face Hub,
# with resume support, progress callbacks, and HF-cache-compatible layout.
  1. import asyncio
  2. import aiohttp
  3. import os
  4. import argparse
  5. from urllib.parse import urljoin
  6. from typing import Callable, Optional, Coroutine, Any
  7. from datetime import datetime, timedelta
  8. from fnmatch import fnmatch
  9. from pathlib import Path
  10. from typing import Generator, Iterable, List, TypeVar, Union
  11. T = TypeVar("T")
  12. DEFAULT_ALLOW_PATTERNS = [
  13. "*.json",
  14. "*.py",
  15. "tokenizer.model",
  16. "*.tiktoken",
  17. "*.txt",
  18. "*.safetensors",
  19. ]
  20. # Always ignore `.git` and `.cache/huggingface` folders in commits
  21. DEFAULT_IGNORE_PATTERNS = [
  22. ".git",
  23. ".git/*",
  24. "*/.git",
  25. "**/.git/**",
  26. ".cache/huggingface",
  27. ".cache/huggingface/*",
  28. "*/.cache/huggingface",
  29. "**/.cache/huggingface/**",
  30. ]
  31. def filter_repo_objects(
  32. items: Iterable[T],
  33. *,
  34. allow_patterns: Optional[Union[List[str], str]] = None,
  35. ignore_patterns: Optional[Union[List[str], str]] = None,
  36. key: Optional[Callable[[T], str]] = None,
  37. ) -> Generator[T, None, None]:
  38. if isinstance(allow_patterns, str):
  39. allow_patterns = [allow_patterns]
  40. if isinstance(ignore_patterns, str):
  41. ignore_patterns = [ignore_patterns]
  42. if allow_patterns is not None:
  43. allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
  44. if ignore_patterns is not None:
  45. ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
  46. if key is None:
  47. def _identity(item: T) -> str:
  48. if isinstance(item, str):
  49. return item
  50. if isinstance(item, Path):
  51. return str(item)
  52. raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
  53. key = _identity
  54. for item in items:
  55. path = key(item)
  56. if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
  57. continue
  58. if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
  59. continue
  60. yield item
  61. def _add_wildcard_to_directories(pattern: str) -> str:
  62. if pattern[-1] == "/":
  63. return pattern + "*"
  64. return pattern
  65. def get_hf_home() -> Path:
  66. """Get the Hugging Face home directory."""
  67. return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
  68. def get_hf_token():
  69. """Retrieve the Hugging Face token from the user's HF_HOME directory."""
  70. token_path = get_hf_home() / "token"
  71. if token_path.exists():
  72. return token_path.read_text().strip()
  73. return None
  74. def get_auth_headers():
  75. """Get authentication headers if a token is available."""
  76. token = get_hf_token()
  77. if token:
  78. return {"Authorization": f"Bearer {token}"}
  79. return {}
  80. def get_repo_root(repo_id: str) -> Path:
  81. """Get the root directory for a given repo ID in the Hugging Face cache."""
  82. sanitized_repo_id = repo_id.replace("/", "--")
  83. return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
  84. async def fetch_file_list(session, repo_id, revision, path=""):
  85. api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
  86. url = f"{api_url}/{path}" if path else api_url
  87. headers = get_auth_headers()
  88. async with session.get(url, headers=headers) as response:
  89. if response.status == 200:
  90. data = await response.json()
  91. files = []
  92. for item in data:
  93. if item["type"] == "file":
  94. files.append({"path": item["path"], "size": item["size"]})
  95. elif item["type"] == "directory":
  96. subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
  97. files.extend(subfiles)
  98. return files
  99. else:
  100. raise Exception(f"Failed to fetch file list: {response.status}")
async def download_file(session, repo_id, revision, file_path, save_directory, progress_callback: Optional[Callable[[str, int, int, float, timedelta], Coroutine[Any, Any, None]]] = None):
    """Download one repo file to save_directory, resuming a partial file via HTTP Range.

    A Range request is sent from the current local file size. Status 200 means
    ranges are unsupported (restart from scratch), 206 means resume, 416 means
    the requested range start is past the end of the remote file. The optional
    progress_callback is awaited per chunk with
    (file_path, bytes_downloaded_this_session, bytes_remaining_at_start, speed_bytes_per_s, eta).
    """
    base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
    url = urljoin(base_url, file_path)
    local_path = os.path.join(save_directory, file_path)
    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    # Check if file already exists and get its size
    if os.path.exists(local_path):
        local_file_size = os.path.getsize(local_path)
    else:
        local_file_size = 0
    # Resume from wherever the local file left off (0 for a fresh download).
    headers = {"Range": f"bytes={local_file_size}-"}
    headers.update(get_auth_headers())
    async with session.get(url, headers=headers) as response:
        if response.status == 200:
            # File doesn't support range requests, start from beginning
            mode = 'wb'
            total_size = int(response.headers.get('Content-Length', 0))
            downloaded_size = 0
        elif response.status == 206:
            # Partial content, resume download; total size comes after the "/"
            # in "Content-Range: bytes start-end/total".
            mode = 'ab'
            content_range = response.headers.get('Content-Range')
            total_size = int(content_range.split('/')[-1])
            downloaded_size = local_file_size
        elif response.status == 416:
            # Range not satisfiable, get the actual file size
            if response.headers.get('Content-Type', '').startswith('text/html'):
                content = await response.text()
                print(f"Response content (HTML):\n{content}")
            else:
                print(response)
                print("Return header: ", response.headers)
                print("Return header: ", response.headers.get('Content-Range').split('/')[-1])
            # NOTE(review): if Content-Range is absent this is int('') -> ValueError;
            # confirm the server always includes it on 416.
            total_size = int(response.headers.get('Content-Range', '').split('/')[-1])
            if local_file_size == total_size:
                print(f"File already fully downloaded: {file_path}")
                return
            else:
                # Start the download from the beginning
                # NOTE(review): this falls through and streams THIS 416 response's
                # body into the file below; a fresh un-ranged GET is presumably
                # what was intended — verify.
                mode = 'wb'
                downloaded_size = 0
        else:
            print(f"Failed to download {file_path}: {response.status}")
            return
        if downloaded_size == total_size:
            print(f"File already downloaded: {file_path}")
            return
        start_time = datetime.now()
        new_downloaded_size = 0  # bytes written during this session only
        with open(local_path, mode) as f:
            async for chunk in response.content.iter_chunked(8192):
                f.write(chunk)
                new_downloaded_size += len(chunk)
                if progress_callback:
                    elapsed_time = (datetime.now() - start_time).total_seconds()
                    speed = new_downloaded_size / elapsed_time if elapsed_time > 0 else 0
                    eta = timedelta(seconds=(total_size - downloaded_size - new_downloaded_size) / speed) if speed > 0 else timedelta(0)
                    await progress_callback(file_path, new_downloaded_size, total_size - downloaded_size, speed, eta)
        print(f"Downloaded: {file_path}")
async def download_all_files(repo_id, revision="main", progress_callback: Optional[Callable[[int, int, int, int, timedelta, dict], Coroutine[Any, Any, None]]] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
    """Download every matching file of a repo revision into the HF cache layout.

    Resolves the revision to a commit hash, records it under refs/, downloads
    all files (filtered by allow/ignore patterns) concurrently into
    snapshots/<commit_hash>/, and awaits progress_callback with
    (completed_files, total_files, new_downloaded_bytes, total_bytes,
    overall_eta, file_progress) on every chunk and file completion.
    """
    repo_root = get_repo_root(repo_id)
    refs_dir = repo_root / "refs"
    snapshots_dir = repo_root / "snapshots"
    # Ensure directories exist
    refs_dir.mkdir(parents=True, exist_ok=True)
    snapshots_dir.mkdir(parents=True, exist_ok=True)
    async with aiohttp.ClientSession() as session:
        # Fetch the commit hash for the given revision
        api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
        headers = get_auth_headers()
        async with session.get(api_url, headers=headers) as response:
            if response.status != 200:
                raise Exception(f"Failed to fetch revision info: {response.status}")
            revision_info = await response.json()
            commit_hash = revision_info['sha']
        # Write the commit hash to the refs file
        refs_file = refs_dir / revision
        refs_file.write_text(commit_hash)
        # Set up the snapshot directory
        snapshot_dir = snapshots_dir / commit_hash
        snapshot_dir.mkdir(exist_ok=True)
        file_list = await fetch_file_list(session, repo_id, revision)
        filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
        total_files = len(filtered_file_list)
        completed_files = 0
        total_bytes = sum(file["size"] for file in filtered_file_list)
        # downloaded_bytes counts bytes reported by per-file callbacks;
        # new_downloaded_bytes tracks the same deltas (both start at 0 even when
        # some files are partially present on disk).
        downloaded_bytes = 0
        new_downloaded_bytes = 0
        file_progress = {file["path"]: {"status": "not_started", "downloaded": 0, "total": file["size"]} for file in filtered_file_list}
        start_time = datetime.now()

        async def download_with_progress(file_info):
            # Wraps download_file so concurrent tasks update the shared counters.
            nonlocal completed_files, downloaded_bytes, new_downloaded_bytes, file_progress

            async def file_progress_callback(path, file_downloaded, file_total, speed, file_eta):
                # Per-chunk callback from download_file: fold the per-file delta
                # into the overall totals, then report overall progress.
                nonlocal downloaded_bytes, new_downloaded_bytes, file_progress
                new_downloaded_bytes += file_downloaded - file_progress[path]['downloaded']
                downloaded_bytes += file_downloaded - file_progress[path]['downloaded']
                file_progress[path].update({
                    'status': 'in_progress',
                    'downloaded': file_downloaded,
                    'total': file_total,
                    'speed': speed,
                    'eta': file_eta
                })
                if progress_callback:
                    elapsed_time = (datetime.now() - start_time).total_seconds()
                    overall_speed = new_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
                    overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
                    await progress_callback(completed_files, total_files, new_downloaded_bytes, total_bytes, overall_eta, file_progress)

            await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
            completed_files += 1
            file_progress[file_info["path"]]['status'] = 'complete'
            if progress_callback:
                elapsed_time = (datetime.now() - start_time).total_seconds()
                overall_speed = new_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
                overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
                await progress_callback(completed_files, total_files, new_downloaded_bytes, total_bytes, overall_eta, file_progress)

        # Download all files concurrently; counters above are shared via nonlocal.
        tasks = [download_with_progress(file_info) for file_info in filtered_file_list]
        await asyncio.gather(*tasks)
  219. async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
  220. async def progress_callback(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress):
  221. print(f"Overall Progress: {completed_files}/{total_files} files, {downloaded_bytes}/{total_bytes} bytes")
  222. print(f"Estimated time remaining: {overall_eta}")
  223. print("File Progress:")
  224. for file_path, progress in file_progress.items():
  225. status_icon = {
  226. 'not_started': '⚪',
  227. 'in_progress': '🔵',
  228. 'complete': '✅'
  229. }[progress['status']]
  230. eta_str = str(progress.get('eta', 'N/A'))
  231. print(f"{status_icon} {file_path}: {progress.get('downloaded', 0)}/{progress['total']} bytes, "
  232. f"Speed: {progress.get('speed', 0):.2f} B/s, ETA: {eta_str}")
  233. print("\n")
  234. await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
  235. if __name__ == "__main__":
  236. parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
  237. parser.add_argument("--repo-id", help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
  238. parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
  239. parser.add_argument("--allow-patterns", nargs="*", default=DEFAULT_ALLOW_PATTERNS, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
  240. parser.add_argument("--ignore-patterns", nargs="*", default=DEFAULT_IGNORE_PATTERNS, help="Patterns of files to ignore (e.g., '.*')")
  241. args = parser.parse_args()
  242. asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))