Browse Source

add --download-quick-check flag to bypass the hf api calls / remote file checks

Alex Cheema 1 year ago
parent
commit
e6902b2fcf
2 changed files with 11 additions and 3 deletions
  1. 9 2
      exo/download/hf/hf_shard_download.py
  2. 2 1
      main.py

+ 9 - 2
exo/download/hf/hf_shard_download.py

@@ -5,11 +5,12 @@ from typing import Dict, List, Tuple
 from exo.inference.shard import Shard
 from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
-from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns
+from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
 from exo.helpers import AsyncCallbackSystem, DEBUG
 
 class HFShardDownloader(ShardDownloader):
-    def __init__(self):
+    def __init__(self, quick_check: bool = False):
+        self.quick_check = quick_check
         self.active_downloads: Dict[Shard, asyncio.Task] = {}
         self.completed_downloads: Dict[Shard, Path] = {}
         self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
@@ -17,6 +18,12 @@ class HFShardDownloader(ShardDownloader):
     async def ensure_shard(self, shard: Shard) -> Path:
         if shard in self.completed_downloads:
             return self.completed_downloads[shard]
+        if self.quick_check:
+            repo_root = get_repo_root(shard.model_id)
+            snapshots_dir = repo_root / "snapshots"
+            if snapshots_dir.exists():
+                most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
+                return most_recent_dir
 
         # If a download on this shard is already in progress, keep that one
         for active_shard in self.active_downloads:

+ 2 - 1
main.py

@@ -20,6 +20,7 @@ parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
+parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for shard download")
 parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
@@ -36,7 +37,7 @@ print_yellow_exo()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
-shard_downloader: ShardDownloader = HFShardDownloader()
+shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check)
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")