hf_shard_download.py

import asyncio
import traceback
from pathlib import Path
from typing import Dict, List, Tuple

from exo.inference.shard import Shard
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent
from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
from exo.helpers import AsyncCallbackSystem, DEBUG


class HFShardDownloader(ShardDownloader):
  def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
    self.quick_check = quick_check
    self.max_parallel_downloads = max_parallel_downloads
    self.active_downloads: Dict[Shard, asyncio.Task] = {}
    self.completed_downloads: Dict[Shard, Path] = {}
    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

  async def ensure_shard(self, shard: Shard) -> Path:
    if shard in self.completed_downloads:
      return self.completed_downloads[shard]
    if self.quick_check:
      # Fast path: reuse the most recently modified local snapshot without checking the remote repo
      repo_root = get_repo_root(shard.model_id)
      snapshots_dir = repo_root/"snapshots"
      if snapshots_dir.exists():
        most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
        return most_recent_dir

    # If a download on this shard is already in progress, keep that one
    for active_shard in self.active_downloads:
      if active_shard == shard:
        if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
        return await self.active_downloads[shard]

    # Cancel any downloads for this model_id on a different shard
    existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
    for active_shard in existing_active_shards:
      if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
      task = self.active_downloads[active_shard]
      task.cancel()
      try:
        await task
      except asyncio.CancelledError:
        pass  # This is expected when cancelling a task
      except Exception as e:
        if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
        traceback.print_exc()
    self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}

    # Start new download
    download_task = asyncio.create_task(self._download_shard(shard))
    self.active_downloads[shard] = download_task
    try:
      path = await download_task
      self.completed_downloads[shard] = path
      return path
    finally:
      # Ensure the task is removed even if an exception occurs
      print(f"Removing download task for {shard}: {shard in self.active_downloads}")
      if shard in self.active_downloads:
        self.active_downloads.pop(shard)

  async def _download_shard(self, shard: Shard) -> Path:
    # Forward per-repo progress events to any subscribers registered on on_progress
    async def wrapped_progress_callback(event: RepoProgressEvent):
      self._on_progress.trigger_all(shard, event)

    # Only fetch the weight files that this shard's layer range actually needs
    weight_map = await get_weight_map(shard.model_id)
    allow_patterns = get_allow_patterns(weight_map, shard)
    return await download_repo_files(repo_id=shard.model_id, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)

  @property
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
    return self._on_progress
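

# Usage sketch: one way this downloader might be driven end to end. The
# register(...).on_next(...) progress API and the Shard constructor fields are
# assumed from how they are used elsewhere in exo; the model_id and layer
# range below are purely illustrative.
if __name__ == "__main__":
  async def _example():
    downloader = HFShardDownloader(quick_check=False, max_parallel_downloads=4)
    # Subscribe to progress events; "example" is an arbitrary subscriber key.
    downloader.on_progress.register("example").on_next(
      lambda shard, event: print(f"{shard.model_id}: {event}")
    )
    # Hypothetical shard covering the first half of a 32-layer model.
    shard = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=15, n_layers=32)
    path = await downloader.ensure_shard(shard)
    print(f"Shard files available at {path}")

  asyncio.run(_example())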