
Merge pull request #138 from exo-explore/better_downloads

add --max-parallel-downloads flag that limits the number of concurrent downloads (fixes #137), use async for all file ops, cache fetch_file_list, cache commit hash, quickly check file sizes on disk before making requests
Alex Cheema 1 year ago
parent
commit
781a71ccff
3 changed files with 97 additions and 43 deletions
  1. exo/download/hf/hf_helpers.py (+90 −39)
  2. exo/download/hf/hf_shard_download.py (+4 −2)
  3. main.py (+3 −2)
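
Illustrative only: a minimal sketch of calling the updated download_repo_files with the new max_parallel_downloads keyword. The repo id and patterns below are placeholders, not values from this PR.

    import asyncio

    from exo.download.hf.hf_helpers import download_repo_files

    async def main():
        # Cap concurrent file downloads at 2 instead of the new default of 4.
        snapshot_dir = await download_repo_files(
            repo_id="some-org/some-model",                # placeholder repo id
            allow_patterns=["*.json", "*.safetensors"],   # placeholder patterns
            max_parallel_downloads=2,
        )
        print(f"Files downloaded to {snapshot_dir}")

    if __name__ == "__main__":
        asyncio.run(main())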

+ 90 - 39
exo/download/hf/hf_helpers.py

@@ -12,6 +12,8 @@ from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_excep
 from exo.helpers import DEBUG
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
+import aiofiles
+from aiofiles import os as aios
 
 T = TypeVar("T")
 def filter_repo_objects(
@@ -56,16 +58,17 @@ def get_hf_home() -> Path:
     """Get the Hugging Face home directory."""
     return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
 
-def get_hf_token():
+async def get_hf_token():
     """Retrieve the Hugging Face token from the user's HF_HOME directory."""
     token_path = get_hf_home() / "token"
-    if token_path.exists():
-        return token_path.read_text().strip()
+    if await aios.path.exists(token_path):
+        async with aiofiles.open(token_path, 'r') as f:
+            return (await f.read()).strip()
     return None
 
-def get_auth_headers():
+async def get_auth_headers():
     """Get authentication headers if a token is available."""
-    token = get_hf_token()
+    token = await get_hf_token()
     if token:
         return {"Authorization": f"Bearer {token}"}
     return {}
@@ -79,7 +82,7 @@ async def fetch_file_list(session, repo_id, revision, path=""):
     api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
     url = f"{api_url}/{path}" if path else api_url
 
-    headers = get_auth_headers()
+    headers = await get_auth_headers()
     async with session.get(url, headers=headers) as response:
         if response.status == 200:
             data = await response.json()
@@ -106,12 +109,12 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
     url = urljoin(base_url, file_path)
     local_path = os.path.join(save_directory, file_path)
 
-    os.makedirs(os.path.dirname(local_path), exist_ok=True)
+    await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
 
     # Check if file already exists and get its size
-    local_file_size = os.path.getsize(local_path) if os.path.exists(local_path) else 0
+    local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
 
-    headers = get_auth_headers()
+    headers = await get_auth_headers()
     if use_range_request:
         headers["Range"] = f"bytes={local_file_size}-"
 
@@ -162,9 +165,9 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
 
         DOWNLOAD_CHUNK_SIZE = 32768
         start_time = datetime.now()
-        with open(local_path, mode) as f:
+        async with aiofiles.open(local_path, mode) as f:
             async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
-                f.write(chunk)
+                await f.write(chunk)
                 downloaded_size += len(chunk)
                 downloaded_this_session += len(chunk)
                 if progress_callback and total_size:
@@ -177,34 +180,60 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
                     await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
         if DEBUG >= 2: print(f"Downloaded: {file_path}")
 
-async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None) -> Path:
+async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, max_parallel_downloads: int = 4) -> Path:
     repo_root = get_repo_root(repo_id)
     refs_dir = repo_root / "refs"
     snapshots_dir = repo_root / "snapshots"
+    cachedreqs_dir = repo_root / "cachedreqs"
 
     # Ensure directories exist
-    refs_dir.mkdir(parents=True, exist_ok=True)
-    snapshots_dir.mkdir(parents=True, exist_ok=True)
+    await aios.makedirs(refs_dir, exist_ok=True)
+    await aios.makedirs(snapshots_dir, exist_ok=True)
+    await aios.makedirs(cachedreqs_dir, exist_ok=True)
+
+    # Check if we have a cached commit hash
+    refs_file = refs_dir / revision
+    if await aios.path.exists(refs_file):
+        async with aiofiles.open(refs_file, 'r') as f:
+            commit_hash = (await f.read()).strip()
+            if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
+    else:
+        async with aiohttp.ClientSession() as session:
+            # Fetch the commit hash for the given revision
+            api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
+            headers = await get_auth_headers()
+            async with session.get(api_url, headers=headers) as response:
+                if response.status != 200:
+                    raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
+                revision_info = await response.json()
+                commit_hash = revision_info['sha']
+
+            # Cache the commit hash
+            async with aiofiles.open(refs_file, 'w') as f:
+                await f.write(commit_hash)
+
+    # Set up the snapshot directory
+    snapshot_dir = snapshots_dir / commit_hash
+    await aios.makedirs(snapshot_dir, exist_ok=True)
+
+    # Set up the cached file list directory
+    cached_file_list_dir = cachedreqs_dir / commit_hash
+    await aios.makedirs(cached_file_list_dir, exist_ok=True)
+    cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
 
     async with aiohttp.ClientSession() as session:
-        # Fetch the commit hash for the given revision
-        api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
-        headers = get_auth_headers()
-        async with session.get(api_url, headers=headers) as response:
-            if response.status != 200:
-                raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
-            revision_info = await response.json()
-            commit_hash = revision_info['sha']
-
-        # Write the commit hash to the refs file
-        refs_file = refs_dir / revision
-        refs_file.write_text(commit_hash)
-
-        # Set up the snapshot directory
-        snapshot_dir = snapshots_dir / commit_hash
-        snapshot_dir.mkdir(exist_ok=True)
-
-        file_list = await fetch_file_list(session, repo_id, revision)
+        # Check if we have a cached file list
+        if await aios.path.exists(cached_file_list_path):
+            async with aiofiles.open(cached_file_list_path, 'r') as f:
+                file_list = json.loads(await f.read())
+            if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
+        else:
+            file_list = await fetch_file_list(session, repo_id, revision)
+            # Cache the file list
+            async with aiofiles.open(cached_file_list_path, 'w') as f:
+                await f.write(json.dumps(file_list))
+            if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
+
         filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
         total_files = len(filtered_file_list)
         total_bytes = sum(file["size"] for file in filtered_file_list)
@@ -212,6 +241,21 @@ async def download_repo_files(repo_id: str, revision: str = "main", progress_cal
         start_time = datetime.now()
 
         async def download_with_progress(file_info, progress_state):
+            local_path = snapshot_dir / file_info["path"]
+            if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
+                if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
+                progress_state['completed_files'] += 1
+                progress_state['downloaded_bytes'] += file_info["size"]
+                file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
+                if progress_callback:
+                    elapsed_time = (datetime.now() - start_time).total_seconds()
+                    overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+                    remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+                    overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+                    status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+                    await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+                return
+
             async def file_progress_callback(event: RepoFileProgressEvent):
                 progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
                 progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
@@ -236,7 +280,12 @@ async def download_repo_files(repo_id: str, revision: str = "main", progress_cal
                 await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
 
         progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
-        tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
+
+        semaphore = asyncio.Semaphore(max_parallel_downloads)
+        async def download_with_semaphore(file_info):
+            async with semaphore:
+                await download_with_progress(file_info, progress_state)
+        tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
         await asyncio.gather(*tasks)
 
     return snapshot_dir
@@ -263,12 +312,14 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[
     # Check if the file exists
     repo_root = get_repo_root(repo_id)
     snapshot_dir = repo_root / "snapshots"
-    index_file = next(snapshot_dir.glob("*/model.safetensors.index.json"), None)
-
-    if index_file and index_file.exists():
-        with open(index_file, 'r') as f:
-            index_data = json.load(f)
-        return index_data.get("weight_map")
+    index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
+
+    if index_file:
+        index_file_path = snapshot_dir / index_file
+        if await aios.path.exists(index_file_path):
+            async with aiofiles.open(index_file_path, 'r') as f:
+                index_data = json.loads(await f.read())
+            return index_data.get("weight_map")
 
     return None
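
For reference, the concurrency cap uses the standard asyncio.Semaphore pattern: every per-file coroutine acquires a shared semaphore before doing any work, so at most max_parallel_downloads transfers run at once while asyncio.gather still waits for all of them. A self-contained sketch of the same idea with a toy job function (not the repo's downloader):

    import asyncio

    async def bounded_gather(coros, max_parallel: int = 4):
        # A shared semaphore limits how many coroutines run concurrently;
        # gather still waits for every one of them to finish.
        semaphore = asyncio.Semaphore(max_parallel)

        async def run(coro):
            async with semaphore:
                return await coro

        return await asyncio.gather(*(run(c) for c in coros))

    async def fake_download(i: int) -> int:
        await asyncio.sleep(0.1)  # stand-in for a network transfer
        return i

    # Runs 10 "downloads", but never more than 3 at a time.
    results = asyncio.run(bounded_gather([fake_download(i) for i in range(10)], max_parallel=3))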
 

+ 4 - 2
exo/download/hf/hf_shard_download.py

@@ -9,8 +9,9 @@ from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, g
 from exo.helpers import AsyncCallbackSystem, DEBUG
 
 class HFShardDownloader(ShardDownloader):
-    def __init__(self, quick_check: bool = False):
+    def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
         self.quick_check = quick_check
+        self.max_parallel_downloads = max_parallel_downloads
         self.active_downloads: Dict[Shard, asyncio.Task] = {}
         self.completed_downloads: Dict[Shard, Path] = {}
         self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
@@ -69,7 +70,8 @@ class HFShardDownloader(ShardDownloader):
         return await download_repo_files(
             repo_id=shard.model_id,
             progress_callback=wrapped_progress_callback,
-            allow_patterns=allow_patterns
+            allow_patterns=allow_patterns,
+            max_parallel_downloads=self.max_parallel_downloads
         )
 
     @property
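
Downstream, HFShardDownloader simply threads the limit through to download_repo_files; a minimal construction sketch (the values are illustrative):

    from exo.download.hf.hf_shard_download import HFShardDownloader

    # quick_check mirrors the --download-quick-check CLI flag;
    # max_parallel_downloads is forwarded to download_repo_files for each shard.
    downloader = HFShardDownloader(quick_check=True, max_parallel_downloads=8)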

+ 3 - 2
main.py

@@ -20,7 +20,8 @@ parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
-parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for shard download")
+parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
+parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
 parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
@@ -37,7 +38,7 @@ print_yellow_exo()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check)
+shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")