shard_download.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536
  1. from abc import ABC, abstractmethod
  2. from typing import Optional, Tuple
  3. from pathlib import Path
  4. from exo.inference.shard import Shard
  5. from exo.download.download_progress import RepoProgressEvent
  6. from exo.helpers import AsyncCallbackSystem
  7. class ShardDownloader(ABC):
  8. @abstractmethod
  9. async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
  10. """
  11. Ensures that the shard is downloaded.
  12. Does not allow multiple overlapping downloads at once.
  13. If you try to download a Shard which overlaps a Shard that is already being downloaded,
  14. the download will be cancelled and a new download will start.
  15. Args:
  16. shard (Shard): The shard to download.
  17. inference_engine_name (str): The inference engine used on the node hosting the shard
  18. """
  19. pass
  20. @property
  21. @abstractmethod
  22. def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
  23. pass
  24. class NoopShardDownloader(ShardDownloader):
  25. async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
  26. return Path("/tmp/noop_shard")
  27. @property
  28. def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
  29. return AsyncCallbackSystem()