# hf_helpers.py

import aiofiles.os as aios
import asyncio
import aiohttp
import json
import os
import sys
import shutil
from urllib.parse import urljoin
from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
from typing import Generator, Iterable, TypeVar, TypedDict
from datetime import datetime, timedelta
from fnmatch import fnmatch
from pathlib import Path
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from exo.helpers import DEBUG, is_frozen
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
from exo.inference.shard import Shard
import aiofiles

T = TypeVar("T")

async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
  refs_dir = get_repo_root(repo_id)/"refs"
  refs_file = refs_dir/revision
  if await aios.path.exists(refs_file):
    async with aiofiles.open(refs_file, 'r') as f:
      commit_hash = (await f.read()).strip()
      snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
      return snapshot_dir
  return None

def filter_repo_objects(
  items: Iterable[T],
  *,
  allow_patterns: Optional[Union[List[str], str]] = None,
  ignore_patterns: Optional[Union[List[str], str]] = None,
  key: Optional[Callable[[T], str]] = None,
) -> Generator[T, None, None]:
  if isinstance(allow_patterns, str):
    allow_patterns = [allow_patterns]
  if isinstance(ignore_patterns, str):
    ignore_patterns = [ignore_patterns]
  if allow_patterns is not None:
    allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
  if ignore_patterns is not None:
    ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]

  if key is None:
    def _identity(item: T) -> str:
      if isinstance(item, str):
        return item
      if isinstance(item, Path):
        return str(item)
      raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")

    key = _identity

  for item in items:
    path = key(item)
    if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
      continue
    if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
      continue
    yield item


def _add_wildcard_to_directories(pattern: str) -> str:
  if pattern[-1] == "/":
    return pattern + "*"
  return pattern

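# Illustrative sketch (not part of the original module): filter_repo_objects operates on
# plain strings, so allow/ignore patterns can be sanity-checked against a made-up file
# list. The paths below are placeholders, not a real repo listing.
def _example_filter_patterns() -> List[str]:
  paths = ["config.json", "model-00001-of-00002.safetensors", "images/logo.png"]
  # Keep JSON and safetensors files; drop everything under images/ (a trailing "/" expands to "images/*")
  return list(filter_repo_objects(paths, allow_patterns=["*.json", "*.safetensors"], ignore_patterns=["images/"]))
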
def get_hf_endpoint() -> str:
  return os.environ.get('HF_ENDPOINT', "https://huggingface.co")


def get_hf_home() -> Path:
  """Get the Hugging Face home directory."""
  return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))


async def get_hf_token():
  """Retrieve the Hugging Face token from the user's HF_HOME directory."""
  token_path = get_hf_home()/"token"
  if await aios.path.exists(token_path):
    async with aiofiles.open(token_path, 'r') as f:
      return (await f.read()).strip()
  return None


async def get_auth_headers():
  """Get authentication headers if a token is available."""
  token = await get_hf_token()
  if token:
    return {"Authorization": f"Bearer {token}"}
  return {}


def get_repo_root(repo_id: str) -> Path:
  """Get the root directory for a given repo ID in the Hugging Face cache."""
  sanitized_repo_id = str(repo_id).replace("/", "--")
  return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"

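# Brief usage sketch (added for illustration, with a placeholder repo id): get_repo_root
# maps a repo id onto the standard HF cache layout, and get_auth_headers yields a Bearer
# header only when a token file exists under HF_HOME.
async def _example_cache_lookup():
  repo_root = get_repo_root("some-org/some-model")  # <HF_HOME>/hub/models--some-org--some-model
  headers = await get_auth_headers()  # {} when no token is cached
  if DEBUG >= 2: print(repo_root, list(headers))
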
async def move_models_to_hf(seed_dir: Union[str, Path]):
  """Move models in the app's resources folder to .cache/huggingface/hub."""
  source_dir = Path(seed_dir)
  dest_dir = get_hf_home()/"hub"
  await aios.makedirs(dest_dir, exist_ok=True)
  # Iterate the seed directory with aiofiles.os so directory I/O stays off the event loop
  for name in await aios.listdir(source_dir):
    path = source_dir/name
    if path.is_dir() and path.name.startswith("models--"):
      dest_path = dest_dir/path.name
      if await aios.path.exists(dest_path):
        if DEBUG >= 1: print(f"skipping moving {dest_path}. File already exists")
      else:
        await aios.rename(str(path), str(dest_path))

async def fetch_file_list(session, repo_id, revision, path=""):
  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
  url = f"{api_url}/{path}" if path else api_url

  headers = await get_auth_headers()
  async with session.get(url, headers=headers) as response:
    if response.status == 200:
      data = await response.json()
      files = []
      for item in data:
        if item["type"] == "file":
          files.append({"path": item["path"], "size": item["size"]})
        elif item["type"] == "directory":
          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
          files.extend(subfiles)
      return files
    else:
      raise Exception(f"Failed to fetch file list: {response.status}")

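# Usage sketch (assumed, not in the original file): fetch_file_list needs a live aiohttp
# session and returns a flat list of {"path", "size"} dicts, recursing into directories.
# The repo id is a placeholder.
async def _example_repo_size(repo_id: str = "some-org/some-model") -> int:
  async with aiohttp.ClientSession() as session:
    files = await fetch_file_list(session, repo_id, "main")
  return sum(file["size"] for file in files)  # total size of the repo in bytes
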
@retry(
  stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
)
async def download_file(
  session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
):
  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
  url = urljoin(base_url, file_path)
  local_path = os.path.join(save_directory, file_path)

  await aios.makedirs(os.path.dirname(local_path), exist_ok=True)

  # Check if file already exists and get its size
  local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0

  headers = await get_auth_headers()
  if use_range_request:
    headers["Range"] = f"bytes={local_file_size}-"

  async with session.get(url, headers=headers) as response:
    total_size = int(response.headers.get('Content-Length', 0))
    downloaded_size = local_file_size
    downloaded_this_session = 0
    mode = 'ab' if use_range_request else 'wb'
    if downloaded_size == total_size:
      if DEBUG >= 2: print(f"File already downloaded: {file_path}")
      if progress_callback:
        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
      return

    if response.status == 200:
      # File doesn't support range requests or we're not using them, start from beginning
      mode = 'wb'
      downloaded_size = 0
    elif response.status == 206:
      # Partial content, resume download
      content_range = response.headers.get('Content-Range', '')
      try:
        total_size = int(content_range.split('/')[-1])
      except ValueError:
        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
    elif response.status == 416:
      # Range not satisfiable, get the actual file size
      content_range = response.headers.get('Content-Range', '')
      try:
        total_size = int(content_range.split('/')[-1])
        if downloaded_size == total_size:
          if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
          if progress_callback:
            await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
          return
      except ValueError:
        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
    else:
      raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")

    if downloaded_size == total_size:
      print(f"File already downloaded: {file_path}")
      if progress_callback:
        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
      return

    DOWNLOAD_CHUNK_SIZE = 32768

    start_time = datetime.now()
    async with aiofiles.open(local_path, mode) as f:
      async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
        await f.write(chunk)
        downloaded_size += len(chunk)
        downloaded_this_session += len(chunk)
        if progress_callback and total_size:
          elapsed_time = (datetime.now() - start_time).total_seconds()
          speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
          remaining_size = total_size - downloaded_size
          eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
          status = "in_progress" if downloaded_size < total_size else "complete"
          if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
          await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
    if DEBUG >= 2: print(f"Downloaded: {file_path}")

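# Hedged usage sketch (added for illustration): download_file resumes partial downloads
# via HTTP Range requests and reports progress through an async callback. Only event
# fields already referenced elsewhere in this module are used; the repo id, file path
# and destination directory are placeholders.
async def _example_download_single_file():
  async def on_file_progress(event: RepoFileProgressEvent):
    if DEBUG >= 2: print(f"{event.file_path}: {event.downloaded} bytes so far")
  async with aiohttp.ClientSession() as session:
    await download_file(session, "some-org/some-model", "main", "config.json", "/tmp/hf_example", progress_callback=on_file_progress)
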
async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
  repo_root = get_repo_root(repo_id)
  refs_dir = repo_root/"refs"
  refs_file = refs_dir/revision

  # Check if we have a cached commit hash
  if await aios.path.exists(refs_file):
    async with aiofiles.open(refs_file, 'r') as f:
      commit_hash = (await f.read()).strip()
      if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
      return commit_hash

  # Fetch the commit hash for the given revision
  async with aiohttp.ClientSession() as session:
    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
    headers = await get_auth_headers()
    async with session.get(api_url, headers=headers) as response:
      if response.status != 200:
        raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
      revision_info = await response.json()
      commit_hash = revision_info['sha']

  # Cache the commit hash
  await aios.makedirs(refs_dir, exist_ok=True)
  async with aiofiles.open(refs_file, 'w') as f:
    await f.write(commit_hash)

  return commit_hash

async def download_repo_files(
  repo_id: str,
  revision: str = "main",
  progress_callback: Optional[RepoProgressCallback] = None,
  allow_patterns: Optional[Union[List[str], str]] = None,
  ignore_patterns: Optional[Union[List[str], str]] = None,
  max_parallel_downloads: int = 4
) -> Path:
  repo_root = get_repo_root(repo_id)
  snapshots_dir = repo_root/"snapshots"
  cachedreqs_dir = repo_root/"cachedreqs"

  # Ensure directories exist
  await aios.makedirs(snapshots_dir, exist_ok=True)
  await aios.makedirs(cachedreqs_dir, exist_ok=True)

  # Resolve revision to commit hash
  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)

  # Set up the snapshot directory
  snapshot_dir = snapshots_dir/commit_hash
  await aios.makedirs(snapshot_dir, exist_ok=True)

  # Set up the cached file list directory
  cached_file_list_dir = cachedreqs_dir/commit_hash
  await aios.makedirs(cached_file_list_dir, exist_ok=True)
  cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"

  async with aiohttp.ClientSession() as session:
    # Check if we have a cached file list
    if await aios.path.exists(cached_file_list_path):
      async with aiofiles.open(cached_file_list_path, 'r') as f:
        file_list = json.loads(await f.read())
        if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
    else:
      file_list = await fetch_file_list(session, repo_id, revision)
      # Cache the file list
      async with aiofiles.open(cached_file_list_path, 'w') as f:
        await f.write(json.dumps(file_list))
        if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")

    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
    total_files = len(filtered_file_list)
    total_bytes = sum(file["size"] for file in filtered_file_list)
    file_progress: Dict[str, RepoFileProgressEvent] = {
      file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
      for file in filtered_file_list
    }
    start_time = datetime.now()

    async def download_with_progress(file_info, progress_state):
      local_path = snapshot_dir/file_info["path"]
      if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
        if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
        progress_state['completed_files'] += 1
        progress_state['downloaded_bytes'] += file_info["size"]
        file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
        if progress_callback:
          elapsed_time = (datetime.now() - start_time).total_seconds()
          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
          status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
          await progress_callback(
            RepoProgressEvent(
              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
              overall_eta, file_progress, status
            )
          )
        return

      async def file_progress_callback(event: RepoFileProgressEvent):
        progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
        progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
        file_progress[event.file_path] = event
        if progress_callback:
          elapsed_time = (datetime.now() - start_time).total_seconds()
          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
          status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
          await progress_callback(
            RepoProgressEvent(
              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
              overall_eta, file_progress, status
            )
          )

      await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)

      progress_state['completed_files'] += 1
      file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
      if progress_callback:
        elapsed_time = (datetime.now() - start_time).total_seconds()
        overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
        remaining_bytes = total_bytes - progress_state['downloaded_bytes']
        overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
        status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
        await progress_callback(
          RepoProgressEvent(
            repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
            overall_eta, file_progress, status
          )
        )

    progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}

    semaphore = asyncio.Semaphore(max_parallel_downloads)

    async def download_with_semaphore(file_info):
      async with semaphore:
        await download_with_progress(file_info, progress_state)

    tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
    await asyncio.gather(*tasks)

  return snapshot_dir

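# Hedged end-to-end sketch (not part of the original module): download only the JSON
# files of a placeholder repo with at most two parallel downloads. The callback prints
# the whole event rather than assuming individual field names.
async def _example_download_repo() -> Path:
  async def on_progress(event: RepoProgressEvent):
    if DEBUG >= 2: print(f"repo progress: {event}")
  return await download_repo_files("some-org/some-model", allow_patterns=["*.json"], progress_callback=on_progress, max_parallel_downloads=2)
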
async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
  """
  Retrieve the weight map from the model.safetensors.index.json file.

  Args:
    repo_id (str): The Hugging Face repository ID.
    revision (str): The revision of the repository to use.

  Returns:
    Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
  """
  # Download the index file
  await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")

  # Check if the file exists
  repo_root = get_repo_root(repo_id)
  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
  snapshot_dir = repo_root/"snapshots"/commit_hash
  index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)

  if index_file:
    index_file_path = snapshot_dir/index_file
    if await aios.path.exists(index_file_path):
      async with aiofiles.open(index_file_path, 'r') as f:
        index_data = json.loads(await f.read())
        return index_data.get("weight_map")

  return None

def extract_layer_num(tensor_name: str) -> Optional[int]:
  # This is a simple example and might need to be adjusted based on the actual naming convention
  parts = tensor_name.split('.')
  for part in parts:
    if part.isdigit():
      return int(part)
  return None


def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
  default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
  shard_specific_patterns = set()
  if weight_map:
    for tensor_name, filename in weight_map.items():
      layer_num = extract_layer_num(tensor_name)
      if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
        shard_specific_patterns.add(filename)
    sorted_file_names = sorted(weight_map.values())
    if shard.is_first_layer():
      shard_specific_patterns.add(sorted_file_names[0])
    elif shard.is_last_layer():
      shard_specific_patterns.add(sorted_file_names[-1])
  else:
    shard_specific_patterns = set(["*.safetensors"])
  if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
  return list(default_patterns | shard_specific_patterns)

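# Hedged sketch (illustrative only): the typical shard-aware flow is to fetch the weight
# map, derive allow patterns for this node's layer range, then download only those files.
# The Shard instance is taken as an argument to avoid assuming its constructor; the repo
# id is a placeholder.
async def _example_shard_download(shard: Shard, repo_id: str = "some-org/some-model") -> Path:
  weight_map = await get_weight_map(repo_id)
  patterns = get_allow_patterns(weight_map, shard)
  return await download_repo_files(repo_id, allow_patterns=patterns)
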
async def has_hf_home_read_access() -> bool:
  hf_home = get_hf_home()
  try: return await aios.access(hf_home, os.R_OK)
  except OSError: return False


async def has_hf_home_write_access() -> bool:
  hf_home = get_hf_home()
  try: return await aios.access(hf_home, os.W_OK)
  except OSError: return False

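# Small usage sketch (added): check cache permissions up front before kicking off downloads.
async def _example_can_use_hf_cache() -> bool:
  return await has_hf_home_read_access() and await has_hf_home_write_access()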