|
@@ -5,11 +5,12 @@ from typing import Dict, List, Tuple
|
|
|
from exo.inference.shard import Shard
|
|
|
from exo.download.shard_download import ShardDownloader
|
|
|
from exo.download.download_progress import RepoProgressEvent
|
|
|
-from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns
|
|
|
+from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
|
|
|
from exo.helpers import AsyncCallbackSystem, DEBUG
|
|
|
|
|
|
class HFShardDownloader(ShardDownloader):
|
|
|
- def __init__(self):
|
|
|
+ def __init__(self, quick_check: bool = False):
|
|
|
+ self.quick_check = quick_check
|
|
|
self.active_downloads: Dict[Shard, asyncio.Task] = {}
|
|
|
self.completed_downloads: Dict[Shard, Path] = {}
|
|
|
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
|
|
@@ -17,6 +18,12 @@ class HFShardDownloader(ShardDownloader):
|
|
|
async def ensure_shard(self, shard: Shard) -> Path:
|
|
|
if shard in self.completed_downloads:
|
|
|
return self.completed_downloads[shard]
|
|
|
+ if self.quick_check:
|
|
|
+ repo_root = get_repo_root(shard.model_id)
|
|
|
+ snapshots_dir = repo_root / "snapshots"
|
|
|
+ if snapshots_dir.exists():
|
|
|
+ most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
|
|
|
+ return most_recent_dir
|
|
|
|
|
|
# If a download on this shard is already in progress, keep that one
|
|
|
for active_shard in self.active_downloads:
|