
Make a separate ShardDownloader abstract class with an HFShardDownloader implementation. This opens up plugging in different methods of downloading model shards, e.g. #79 / #16

Alex Cheema, 8 months ago
commit 476a714bbb

+ 1 - 4
exo/api/chatgpt_api.py

@@ -103,10 +103,7 @@ async def resolve_tokenizer(model_id: str):
     if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())
 
-  if DEBUG >= 4: print(f"Trying mlx tokenizer for {model_id}")
-  from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
-
-  return load_tokenizer(await get_model_path(model_id))
+  raise ValueError(f"[TODO] Unsupported model: {model_id}")
 
 
 def generate_completion(

+ 74 - 0
exo/download/download_progress.py

@@ -0,0 +1,74 @@
+from typing import Dict, Callable, Coroutine, Any, Literal
+from dataclasses import dataclass
+from datetime import timedelta
+
+@dataclass
+class RepoFileProgressEvent:
+    file_path: str
+    downloaded: int
+    downloaded_this_session: int
+    total: int
+    speed: int
+    eta: timedelta
+    status: Literal["not_started", "in_progress", "complete"]
+
+    def to_dict(self):
+        return {
+            "file_path": self.file_path,
+            "downloaded": self.downloaded,
+            "downloaded_this_session": self.downloaded_this_session,
+            "total": self.total,
+            "speed": self.speed,
+            "eta": self.eta.total_seconds(),
+            "status": self.status
+        }
+
+    @classmethod
+    def from_dict(cls, data):
+        # Convert eta from seconds back to timedelta
+        if 'eta' in data:
+            data['eta'] = timedelta(seconds=data['eta'])
+        return cls(**data)
+
+@dataclass
+class RepoProgressEvent:
+    completed_files: int
+    total_files: int
+    downloaded_bytes: int
+    downloaded_bytes_this_session: int
+    total_bytes: int
+    overall_speed: int
+    overall_eta: timedelta
+    file_progress: Dict[str, RepoFileProgressEvent]
+    status: Literal["not_started", "in_progress", "complete"]
+
+    def to_dict(self):
+        return {
+            "completed_files": self.completed_files,
+            "total_files": self.total_files,
+            "downloaded_bytes": self.downloaded_bytes,
+            "downloaded_bytes_this_session": self.downloaded_bytes_this_session,
+            "total_bytes": self.total_bytes,
+            "overall_speed": self.overall_speed,
+            "overall_eta": self.overall_eta.total_seconds(),
+            "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
+            "status": self.status
+        }
+
+    @classmethod
+    def from_dict(cls, data):
+        # Convert overall_eta from seconds back to timedelta
+        if 'overall_eta' in data:
+            data['overall_eta'] = timedelta(seconds=data['overall_eta'])
+
+        # Parse file_progress
+        if 'file_progress' in data:
+            data['file_progress'] = {
+                k: RepoFileProgressEvent.from_dict(v)
+                for k, v in data['file_progress'].items()
+            }
+
+        return cls(**data)
+
+RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
+RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]
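These progress events cross process boundaries as JSON (see the download_progress broadcast in main.py below), so to_dict/from_dict must round-trip exactly. A minimal sketch of that round-trip, using only the dataclasses defined above; the file name and byte counts are made up:

import json
from datetime import timedelta
from exo.download.download_progress import RepoFileProgressEvent, RepoProgressEvent

file_event = RepoFileProgressEvent(
    file_path="model.safetensors", downloaded=1024, downloaded_this_session=512,
    total=4096, speed=256, eta=timedelta(seconds=12), status="in_progress")
repo_event = RepoProgressEvent(
    completed_files=0, total_files=1, downloaded_bytes=1024,
    downloaded_bytes_this_session=512, total_bytes=4096, overall_speed=256,
    overall_eta=timedelta(seconds=12),
    file_progress={"model.safetensors": file_event}, status="in_progress")

wire = json.dumps(repo_event.to_dict())  # timedeltas become float seconds
assert RepoProgressEvent.from_dict(json.loads(wire)) == repo_event  # nested events restored too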

+ 52 - 89
exo/inference/hf_helpers.py → exo/download/hf/hf_helpers.py

@@ -1,5 +1,6 @@
 import asyncio
 import aiohttp
+import json
 import os
 from urllib.parse import urljoin
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
@@ -7,9 +8,9 @@ from datetime import datetime, timedelta
 from fnmatch import fnmatch
 from pathlib import Path
 from typing import Generator, Iterable, TypeVar, TypedDict
-from dataclasses import dataclass
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
 from exo.helpers import DEBUG
+from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 
 T = TypeVar("T")
 def filter_repo_objects(
@@ -21,10 +22,8 @@ def filter_repo_objects(
 ) -> Generator[T, None, None]:
     if isinstance(allow_patterns, str):
         allow_patterns = [allow_patterns]
-
     if isinstance(ignore_patterns, str):
         ignore_patterns = [ignore_patterns]
-
     if allow_patterns is not None:
         allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
     if ignore_patterns is not None:
@@ -37,18 +36,14 @@ def filter_repo_objects(
             if isinstance(item, Path):
                 return str(item)
             raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
-
         key = _identity
 
     for item in items:
         path = key(item)
-
         if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
             continue
-
         if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
             continue
-
         yield item
 
 def _add_wildcard_to_directories(pattern: str) -> str:
@@ -99,84 +94,13 @@ async def fetch_file_list(session, repo_id, revision, path=""):
             raise Exception(f"Failed to fetch file list: {response.status}")
 
 
-@dataclass
-class HFRepoFileProgressEvent:
-    file_path: str
-    downloaded: int
-    downloaded_this_session: int
-    total: int
-    speed: int
-    eta: timedelta
-    status: Literal["not_started", "in_progress", "complete"]
-
-    def to_dict(self):
-        return {
-            "file_path": self.file_path,
-            "downloaded": self.downloaded,
-            "downloaded_this_session": self.downloaded_this_session,
-            "total": self.total,
-            "speed": self.speed,
-            "eta": self.eta.total_seconds(),
-            "status": self.status
-        }
-
-    @classmethod
-    def from_dict(cls, data):
-        # Convert eta from seconds back to timedelta
-        if 'eta' in data:
-            data['eta'] = timedelta(seconds=data['eta'])
-        return cls(**data)
-
-@dataclass
-class HFRepoProgressEvent:
-    completed_files: int
-    total_files: int
-    downloaded_bytes: int
-    downloaded_bytes_this_session: int
-    total_bytes: int
-    overall_speed: int
-    overall_eta: timedelta
-    file_progress: Dict[str, HFRepoFileProgressEvent]
-    status: Literal["not_started", "in_progress", "complete"]
-
-    def to_dict(self):
-        return {
-            "completed_files": self.completed_files,
-            "total_files": self.total_files,
-            "downloaded_bytes": self.downloaded_bytes,
-            "downloaded_bytes_this_session": self.downloaded_bytes_this_session,
-            "total_bytes": self.total_bytes,
-            "overall_speed": self.overall_speed,
-            "overall_eta": self.overall_eta.total_seconds(),
-            "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
-            "status": self.status
-        }
-
-    @classmethod
-    def from_dict(cls, data):
-        # Convert overall_eta from seconds back to timedelta
-        if 'overall_eta' in data:
-            data['overall_eta'] = timedelta(seconds=data['overall_eta'])
-
-        # Parse file_progress
-        if 'file_progress' in data:
-            data['file_progress'] = {
-                k: HFRepoFileProgressEvent.from_dict(v)
-                for k, v in data['file_progress'].items()
-            }
-
-        return cls(**data)
-
-HFRepoFileProgressCallback = Callable[[HFRepoFileProgressEvent], Coroutine[Any, Any, None]]
-HFRepoProgressCallback = Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]
-
 @retry(
     stop=stop_after_attempt(5),
     wait=wait_exponential(multiplier=1, min=4, max=60),
     retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
     reraise=True
 )
-async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[HFRepoFileProgressCallback] = None, use_range_request: bool = True):
+async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True):
     base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
     url = urljoin(base_url, file_path)
     local_path = os.path.join(save_directory, file_path)
@@ -198,7 +122,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
         if downloaded_size == total_size:
             if DEBUG >= 2: print(f"File already downloaded: {file_path}")
             if progress_callback:
-                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+                await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
             return
 
         if response.status == 200:
@@ -221,7 +145,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
                 if downloaded_size == total_size:
                     if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
                     if progress_callback:
-                        await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+                        await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
                     return
             except ValueError:
                 if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
@@ -232,7 +156,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
         if downloaded_size == total_size:
             print(f"File already downloaded: {file_path}")
             if progress_callback:
-                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+                await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
             return
 
         DOWNLOAD_CHUNK_SIZE = 32768
@@ -249,10 +173,10 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
                     eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
                     status = "in_progress" if downloaded_size < total_size else "complete"
                     if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
-                    await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
+                    await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
         if DEBUG >= 2: print(f"Downloaded: {file_path}")
 
-async def download_all_files(repo_id: str, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
+async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None) -> Path:
     repo_root = get_repo_root(repo_id)
     refs_dir = repo_root / "refs"
     snapshots_dir = repo_root / "snapshots"
@@ -283,11 +207,11 @@ async def download_all_files(repo_id: str, revision: str = "main", progress_call
         filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
         total_files = len(filtered_file_list)
         total_bytes = sum(file["size"] for file in filtered_file_list)
-        file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
+        file_progress: Dict[str, RepoFileProgressEvent] = {file["path"]: RepoFileProgressEvent(file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
         start_time = datetime.now()
 
         async def download_with_progress(file_info, progress_state):
-            async def file_progress_callback(event: HFRepoFileProgressEvent):
+            async def file_progress_callback(event: RepoFileProgressEvent):
                 progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
                 progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
                 file_progress[event.file_path] = event
@@ -297,21 +221,60 @@ async def download_all_files(repo_id: str, revision: str = "main", progress_call
                     remaining_bytes = total_bytes - progress_state['downloaded_bytes']
                     overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
                     status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
-                    await progress_callback(HFRepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+                    await progress_callback(RepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
 
             await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
             progress_state['completed_files'] += 1
-            file_progress[file_info["path"]] = HFRepoFileProgressEvent(file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
+            file_progress[file_info["path"]] = RepoFileProgressEvent(file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
             if progress_callback:
                 elapsed_time = (datetime.now() - start_time).total_seconds()
                 overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
                 remaining_bytes = total_bytes - progress_state['downloaded_bytes']
                 overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
                 status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-                await progress_callback(HFRepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+                await progress_callback(RepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
 
         progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
         tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
         await asyncio.gather(*tasks)
 
     return snapshot_dir
+
+async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
+    """
+    Retrieve the weight map from the model.safetensors.index.json file.
+
+    Args:
+        repo_id (str): The Hugging Face repository ID.
+        revision (str): The revision of the repository to use.
+
+    Returns:
+        Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
+    """
+
+    # Download the index file
+    await download_repo_files(
+        repo_id=repo_id,
+        revision=revision,
+        allow_patterns="model.safetensors.index.json"
+    )
+
+    # Check if the file exists
+    repo_root = get_repo_root(repo_id)
+    snapshot_dir = repo_root / "snapshots"
+    index_file = next(snapshot_dir.glob("*/model.safetensors.index.json"), None)
+
+    if index_file and index_file.exists():
+        with open(index_file, 'r') as f:
+            index_data = json.load(f)
+        return index_data.get("weight_map")
+
+    return None
+
+def extract_layer_num(tensor_name: str) -> Optional[int]:
+    # This is a simple example and might need to be adjusted based on the actual naming convention
+    parts = tensor_name.split('.')
+    for part in parts:
+        if part.isdigit():
+            return int(part)
+    return None
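Together, these two helpers let a caller map a layer range to the subset of safetensors files it actually needs. A sketch under the common Hugging Face naming convention (model.layers.N....); real repos may name tensors differently, which is why extract_layer_num stays deliberately loose:

from exo.download.hf.hf_helpers import extract_layer_num

# Hypothetical weight map, shaped like the output of get_weight_map
weight_map = {
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.31.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
}
# Files needed for layers 0..15; embeddings have no layer number and fall through
needed = {f for t, f in weight_map.items()
          if (n := extract_layer_num(t)) is not None and 0 <= n <= 15}
assert needed == {"model-00001-of-00002.safetensors"}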

+ 76 - 0
exo/download/hf/hf_shard_download.py

@@ -0,0 +1,76 @@
+import asyncio
+from pathlib import Path
+from typing import Dict, List, Tuple
+from exo.inference.shard import Shard
+from exo.download.shard_download import ShardDownloader
+from exo.download.download_progress import RepoProgressEvent
+from exo.download.hf.hf_helpers import download_repo_files, get_repo_root, get_weight_map, extract_layer_num
+from exo.helpers import AsyncCallbackSystem
+
+class HFShardDownloader(ShardDownloader):
+    def __init__(self):
+        self.active_downloads: List[Tuple[Shard, asyncio.Task]] = []
+        self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+
+    async def ensure_shard(self, shard: Shard) -> Path:
+        # Cancel any overlapping downloads
+        to_remove = []
+        for active_shard, task in self.active_downloads:
+            if shard.overlaps(active_shard):
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass  # This is expected when cancelling a task
+                to_remove.append((active_shard, task))
+
+        # Remove cancelled downloads from the list
+        for item in to_remove:
+            self.active_downloads.remove(item)
+
+        # Start new download
+        download_task = asyncio.create_task(self._download_shard(shard))
+        self.active_downloads.append((shard, download_task))
+
+        try:
+            return await download_task
+        finally:
+            # Ensure the task is removed even if an exception occurs
+            if (shard, download_task) in self.active_downloads:
+                self.active_downloads.remove((shard, download_task))
+
+    async def _download_shard(self, shard: Shard) -> Path:
+        async def wrapped_progress_callback(event: RepoProgressEvent):
+            self._on_progress.trigger_all(shard, event)
+
+        weight_map = await get_weight_map(shard.model_id)
+        allow_patterns = self._get_allow_patterns(weight_map, shard.start_layer, shard.end_layer)
+
+        return await download_repo_files(
+            repo_id=shard.model_id,
+            progress_callback=wrapped_progress_callback,
+            allow_patterns=allow_patterns
+        )
+
+    @staticmethod
+    def _get_allow_patterns(weight_map: Dict[str, str], start_layer: int, end_layer: int) -> List[str]:
+        default_patterns = [
+            "*.json",
+            "*.py",
+            "tokenizer.model",
+            "*.tiktoken",
+            "*.txt",
+        ]
+        shard_specific_patterns = []
+        if weight_map:
+            for tensor_name, filename in weight_map.items():
+                layer_num = extract_layer_num(tensor_name)
+                if layer_num is not None and start_layer <= layer_num <= end_layer:
+                    shard_specific_patterns.append(filename)
+        else:
+            shard_specific_patterns = ["*.safetensors"]
+        return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates
+
+    @property
+    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+        return self._on_progress
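A usage sketch for the new downloader; the model_id matches one of the repos listed in the table removed from sharded_utils.py below, but any Hugging Face repo id would do:

import asyncio
from exo.inference.shard import Shard
from exo.download.hf.hf_shard_download import HFShardDownloader

async def main():
    downloader = HFShardDownloader()
    # Callbacks receive (shard, event), mirroring trigger_all above
    downloader.on_progress.register("log").on_next(
        lambda shard, event: print(f"{shard.model_id}: {event.downloaded_bytes}/{event.total_bytes} bytes"))
    shard = Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
                  start_layer=0, end_layer=15, n_layers=32)
    print(f"shard files under {await downloader.ensure_shard(shard)}")

asyncio.run(main())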

+ 25 - 0
exo/download/shard_download.py

@@ -0,0 +1,25 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple
+from pathlib import Path
+from exo.inference.shard import Shard
+from exo.download.download_progress import RepoProgressEvent
+from exo.helpers import AsyncCallbackSystem
+
+class ShardDownloader(ABC):
+    @abstractmethod
+    async def ensure_shard(self, shard: Shard) -> Path:
+        """
+        Ensures that the shard is downloaded.
+        Does not allow multiple overlapping downloads at once.
+        If you try to download a Shard which overlaps a Shard that is already being downloaded,
+        the download will be cancelled and a new download will start.
+
+        Args:
+            shard (Shard): The shard to download.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+        pass
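This abstract class is the plug-in point the commit message refers to. A hypothetical alternative implementation that resolves shards from a pre-seeded local directory, to show how little an implementation needs to provide (LocalShardDownloader is illustrative, not part of this commit):

from pathlib import Path
from typing import Tuple
from exo.inference.shard import Shard
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import AsyncCallbackSystem

class LocalShardDownloader(ShardDownloader):
    def __init__(self, base_dir: Path):
        self.base_dir = base_dir
        self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

    async def ensure_shard(self, shard: Shard) -> Path:
        # No network: just verify the files were staged ahead of time
        path = self.base_dir / shard.model_id.replace("/", "--")
        if not path.exists():
            raise FileNotFoundError(f"no local copy of {shard.model_id} under {self.base_dir}")
        return path

    @property
    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
        return self._on_progress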

+ 3 - 3
exo/helpers.py

@@ -31,17 +31,17 @@ def get_system_info():
   return "Non-Mac, non-Linux system"
 
 
-def get_inference_engine(inference_engine_name):
+def get_inference_engine(inference_engine_name, shard_downloader: 'ShardDownloader'):
   if inference_engine_name == "mlx":
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 
-    return MLXDynamicShardInferenceEngine()
+    return MLXDynamicShardInferenceEngine(shard_downloader)
   elif inference_engine_name == "tinygrad":
     from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
     import tinygrad.helpers
     tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
 
-    return TinygradDynamicShardInferenceEngine()
+    return TinygradDynamicShardInferenceEngine(shard_downloader)
   else:
     raise ValueError(f"Inference engine {inference_engine_name} not supported")
 

+ 1 - 6
exo/inference/inference_engine.py

@@ -1,9 +1,8 @@
 import numpy as np
 
-from typing import Tuple, Optional, Callable, Coroutine, Any
+from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from .shard import Shard
-from exo.inference.hf_helpers import HFRepoProgressEvent
 
 class InferenceEngine(ABC):
   @abstractmethod
@@ -13,7 +12,3 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
-
-  @abstractmethod
-  def set_progress_callback(self, progress_callback: Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]):
-    pass

+ 6 - 8
exo/inference/mlx/sharded_inference_engine.py

@@ -4,14 +4,14 @@ from ..inference_engine import InferenceEngine
 from .sharded_model import StatefulShardedModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
-from typing import Optional, Callable
-from exo.inference.hf_helpers import HFRepoProgressCallback
+from typing import Optional
+from exo.download.shard_download import ShardDownloader
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self, progress_callback: Optional[HFRepoProgressCallback] = None):
+  def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
-    self.progress_callback = progress_callback
+    self.shard_downloader = shard_downloader
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
@@ -34,9 +34,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return
 
-    model_shard, self.tokenizer = await load_shard(shard.model_id, shard, progress_callback=self.progress_callback)
+    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_shard, self.tokenizer = await load_shard(model_path, shard)
     self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
     self.shard = shard
-
-  def set_progress_callback(self, progress_callback: HFRepoProgressCallback):
-    self.progress_callback = progress_callback

+ 1 - 175
exo/inference/mlx/sharded_utils.py

@@ -15,18 +15,12 @@ import base64
 
 import mlx.core as mx
 import mlx.nn as nn
-from huggingface_hub import snapshot_download, list_repo_tree, get_paths_info
-from huggingface_hub.utils import filter_repo_objects
-from huggingface_hub.file_download import repo_folder_name
-from huggingface_hub.constants import HF_HUB_CACHE
-from huggingface_hub.utils._errors import RepositoryNotFoundError
 from transformers import AutoProcessor
 
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tuner.utils import apply_lora_layers
 
 from exo import DEBUG
-from exo.inference.hf_helpers import download_all_files, HFRepoProgressCallback
 from ..shard import Shard
 
 
@@ -162,182 +156,14 @@ def load_model_shard(
   model.eval()
   return model
 
-
-repo_id_safetensors_layers = {
-  "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit": {
-    "model.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
-  },
-  "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit": {
-    "model-00001-of-00008.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
-    "model-00002-of-00008.safetensors": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
-    "model-00003-of-00008.safetensors": [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
-    "model-00004-of-00008.safetensors": [31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42],
-    "model-00005-of-00008.safetensors": [42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53],
-    "model-00006-of-00008.safetensors": [53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64],
-    "model-00007-of-00008.safetensors": [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75],
-    "model-00008-of-00008.safetensors": [75, 76, 77, 78, 79],
-  },
-  "mlx-community/Meta-Llama-3.1-405B-Instruct-4bit": {
-    "model-00001-of-00046.safetensors": [0, 1, 2],
-    "model-00002-of-00046.safetensors": [2, 3, 4, 5],
-    "model-00003-of-00046.safetensors": [5, 6, 7],
-    "model-00004-of-00046.safetensors": [8, 9, 10],
-    "model-00005-of-00046.safetensors": [10, 11, 12, 13],
-    "model-00006-of-00046.safetensors": [13, 14, 15, 16],
-    "model-00007-of-00046.safetensors": [16, 17, 18, 19],
-    "model-00008-of-00046.safetensors": [19, 20, 21],
-    "model-00009-of-00046.safetensors": [22, 23, 24],
-    "model-00010-of-00046.safetensors": [24, 25, 26, 27],
-    "model-00011-of-00046.safetensors": [27, 28, 29, 30],
-    "model-00012-of-00046.safetensors": [30, 31, 32, 33],
-    "model-00013-of-00046.safetensors": [33, 34, 35],
-    "model-00014-of-00046.safetensors": [36, 37, 38],
-    "model-00015-of-00046.safetensors": [38, 39, 40, 41],
-    "model-00016-of-00046.safetensors": [41, 42, 43, 44],
-    "model-00017-of-00046.safetensors": [44, 45, 46, 47],
-    "model-00018-of-00046.safetensors": [47, 48, 49],
-    "model-00019-of-00046.safetensors": [50, 51, 52],
-    "model-00020-of-00046.safetensors": [52, 53, 54, 55],
-    "model-00021-of-00046.safetensors": [55, 56, 57, 58],
-    "model-00022-of-00046.safetensors": [58, 59, 60, 61],
-    "model-00023-of-00046.safetensors": [61, 62, 63],
-    "model-00024-of-00046.safetensors": [64, 65, 66],
-    "model-00025-of-00046.safetensors": [66, 67, 68, 69],
-    "model-00026-of-00046.safetensors": [69, 70, 71, 72],
-    "model-00027-of-00046.safetensors": [72, 73, 74, 75],
-    "model-00028-of-00046.safetensors": [75, 76, 77],
-    "model-00029-of-00046.safetensors": [78, 79, 80],
-    "model-00030-of-00046.safetensors": [80, 81, 82, 83],
-    "model-00031-of-00046.safetensors": [83, 84, 85, 86],
-    "model-00032-of-00046.safetensors": [86, 87, 88, 89],
-    "model-00033-of-00046.safetensors": [89, 90, 91],
-    "model-00034-of-00046.safetensors": [92, 93, 94],
-    "model-00035-of-00046.safetensors": [94, 95, 96, 97],
-    "model-00036-of-00046.safetensors": [97, 98, 99, 100],
-    "model-00037-of-00046.safetensors": [100, 101, 102, 103],
-    "model-00038-of-00046.safetensors": [103, 104, 105],
-    "model-00039-of-00046.safetensors": [106, 107, 108],
-    "model-00040-of-00046.safetensors": [108, 109, 110, 111],
-    "model-00041-of-00046.safetensors": [111, 112, 113, 114],
-    "model-00042-of-00046.safetensors": [114, 115, 116, 117],
-    "model-00043-of-00046.safetensors": [117, 118, 119],
-    "model-00044-of-00046.safetensors": [120, 121, 122],
-    "model-00045-of-00046.safetensors": [122, 123, 124, 125],
-    "model-00046-of-00046.safetensors": [125]
-  },
-  "mlx-community/Mistral-Nemo-Instruct-2407-4bit": {
-    "model-00001-of-00002.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32],
-    "model-00002-of-00002.safetensors": [32, 33, 34, 35, 36, 37, 38, 39],
-  },
-  "mlx-community/Mistral-Large-Instruct-2407-4bit": {
-    "model-00001-of-00014.safetensors": [0, 1, 2, 3, 4, 5, 6],
-    "model-00002-of-00014.safetensors": [6, 7, 8, 9, 10, 11, 12, 13],
-    "model-00003-of-00014.safetensors": [13, 14, 15, 16, 17, 18, 19, 20],
-    "model-00004-of-00014.safetensors": [20, 21, 22, 23, 24, 25, 26],
-    "model-00005-of-00014.safetensors": [27, 28, 29, 30, 31, 32, 33],
-    "model-00006-of-00014.safetensors": [33, 34, 35, 36, 37, 38, 39, 40],
-    "model-00007-of-00014.safetensors": [40, 41, 42, 43, 44, 45, 46, 47],
-    "model-00008-of-00014.safetensors": [47, 48, 49, 50, 51, 52, 53, 54],
-    "model-00009-of-00014.safetensors": [54, 55, 56, 57, 58, 59, 60],
-    "model-00010-of-00014.safetensors": [61, 62, 63, 64, 65, 66, 67],
-    "model-00011-of-00014.safetensors": [67, 68, 69, 70, 71, 72, 73, 74],
-    "model-00012-of-00014.safetensors": [74, 75, 76, 77, 78, 79, 80, 81],
-    "model-00013-of-00014.safetensors": [81, 82, 83, 84, 85, 86, 87],
-    "model-00014-of-00014.safetensors": [87]
-  },
-  "llava-hf/llava-1.5-7b-hf": {
-    "model-00001-of-00003.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
-    "model-00002-of-00003.safetensors": [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22],
-    "model-00003-of-00003.safetensors": [22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
-  }
-}
-
-def get_safetensors_allow_patterns(repo_id: str, shard: Optional[Shard] = None):
-    return ["*.safetensors"] # TODO: enable this
-    if not shard:
-      return ["*.safetensors"]
-
-    allow_patterns = []
-    for repo_id, safetensors_layers in repo_id_safetensors_layers.items():
-        if repo_id == shard.model_id:
-            for safetensor, layers in safetensors_layers.items():
-                if any(shard.start_layer <= layer <= shard.end_layer for layer in layers):
-                    allow_patterns.append(safetensor)
-
-    return allow_patterns if len(allow_patterns) > 0 else ["*.safetensors"]
-
-async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None) -> Path:
-  """
-  Ensures the model is available locally. If the path does not exist locally,
-  it is downloaded from the Hugging Face Hub.
-
-  Args:
-   path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
-   revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
-
-  Returns:
-   Path: The path to the model.
-  """
-  model_path = Path(path_or_hf_repo)
-  if not model_path.exists():
-    try:
-      model_path = Path(
-        await download_all_files(
-          repo_id=path_or_hf_repo,
-          revision=revision,
-          allow_patterns=[
-            "*.json",
-            "*.py",
-            "tokenizer.model",
-            "*.tiktoken",
-            "*.txt",
-          ] + get_safetensors_allow_patterns(path_or_hf_repo, shard),
-          progress_callback=progress_callback,
-        )
-      )
-    except RepositoryNotFoundError:
-      raise ModelNotFoundError(
-        f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
-        "Please make sure you specified the local path or Hugging Face"
-        " repo id correctly.\nIf you are trying to access a private or"
-        " gated Hugging Face repo, make sure you are authenticated:\n"
-        "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
-      ) from None
-  return model_path
-
-
 async def load_shard(
-  path_or_hf_repo: str,
+  model_path: str,
   shard: Shard,
   tokenizer_config={},
   model_config={},
   adapter_path: Optional[str] = None,
   lazy: bool = False,
-  progress_callback: Optional[HFRepoProgressCallback] = None,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
-  """
-  Load the model and tokenizer from a given path or a huggingface repository.
-
-  Args:
-   path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
-   tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
-    Defaults to an empty dictionary.
-   model_config(dict, optional): Configuration parameters specifically for the model.
-    Defaults to an empty dictionary.
-   adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
-    to the model. Default: ``None``.
-   lazy (bool): If False eval the model parameters to make sure they are
-    loaded in memory before returning, otherwise they will be loaded
-    when needed. Default: ``False``
-  Returns:
-   Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
-
-  Raises:
-   FileNotFoundError: If config file or safetensors are not found.
-   ValueError: If model class or args class are not found.
-  """
-  model_path = await get_model_path(path_or_hf_repo, shard, progress_callback=progress_callback)
-
   model = load_model_shard(model_path, shard, lazy, model_config)
   if adapter_path is not None:
     model = apply_lora_layers(model, adapter_path)

+ 9 - 0
exo/inference/shard.py

@@ -24,3 +24,12 @@ class Shard:
       "end_layer": self.end_layer,
       "n_layers": self.n_layers,
     }
+
+  def overlaps(self, other: 'Shard') -> bool:
+    return shards_overlap(self, other)
+
+def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
+  return (
+      shard1.model_id == shard2.model_id
+      and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer)
+  )
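Note the comparison is inclusive on both ends, so two shards that merely share a boundary layer count as overlapping, while shards of different models never do. A quick check of those semantics (assuming Shard is constructed with the fields shown above):

from exo.inference.shard import Shard

a = Shard(model_id="m", start_layer=0, end_layer=15, n_layers=32)
b = Shard(model_id="m", start_layer=15, end_layer=31, n_layers=32)
c = Shard(model_id="m", start_layer=16, end_layer=31, n_layers=32)
assert a.overlaps(b)      # shares layer 15
assert not a.overlaps(c)  # disjoint ranges
assert not a.overlaps(Shard(model_id="other", start_layer=0, end_layer=15, n_layers=32))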

+ 5 - 8
exo/inference/tinygrad/inference.py

@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import List, Optional, Union, Callable, Coroutine, Any
+from typing import List, Optional
 import json
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
@@ -8,7 +8,7 @@ from tinygrad.helpers import tqdm
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
 import numpy as np
-from exo.inference.hf_helpers import HFRepoProgressCallback, HFRepoProgressEvent, download_all_files
+from exo.download.shard_download import ShardDownloader
 
 MODEL_PARAMS = {
   "8B": {
@@ -147,9 +147,9 @@ def prefill(model, toks, start_pos=0):
 
 
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self, progress_callback: Optional[HFRepoProgressCallback] = None):
+  def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
-    self.progress_callback = progress_callback
+    self.shard_downloader = shard_downloader
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
@@ -188,7 +188,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return
 
-    model_path = await download_all_files(shard.model_id, progress_callback=self.progress_callback)
+    model_path = await self.shard_downloader.ensure_shard(shard)
     print(f"{model_path=}")
     model = build_transformer(model_path, shard=shard, model_size="8B" if "8b" in shard.model_id else "70B" if "70b" in shard.model_id else "8B")
     from transformers import AutoTokenizer
@@ -197,6 +197,3 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard = shard
     self.model = model
     self.tokenizer = tokenizer
-
-  def set_progress_callback(self, progress_callback: Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]):
-    self.progress_callback = progress_callback

+ 8 - 8
exo/orchestration/standard_node.py

@@ -14,7 +14,7 @@ from exo.topology.partitioning_strategy import Partition, PartitioningStrategy,
 from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
-from exo.inference.hf_helpers import HFRepoProgressEvent
+from exo.download.hf.hf_helpers import RepoProgressEvent
 
 
 class StandardNode(Node):
@@ -25,7 +25,7 @@ class StandardNode(Node):
     inference_engine: InferenceEngine,
     discovery: Discovery,
     partitioning_strategy: PartitioningStrategy = None,
-    max_generate_tokens: int = 256,
+    max_generate_tokens: int = 1024,
     chatgpt_api_endpoint: Optional[str] = None,
     web_chat_url: Optional[str] = None,
     disable_tui: Optional[bool] = False,
@@ -44,6 +44,7 @@ class StandardNode(Node):
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
+    self.node_download_progress: Dict[str, RepoProgressEvent] = {}
 
   def on_node_status(self, request_id, opaque_status):
     try:
@@ -57,14 +58,13 @@ class StandardNode(Node):
       download_progress = None
       if status_data.get("type", "") == "download_progress":
         if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
-        if status_data.get("node_id") == self.id:
-          download_progress = HFRepoProgressEvent.from_dict(status_data.get('progress'))
+        download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
+        self.node_download_progress[status_data.get('node_id')] = download_progress
       if self.topology_viz:
-        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), download_progress)
+        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id, self.node_download_progress)
     except Exception as e:
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
-      traceback.print_exc()
-      pass
+      if DEBUG >= 1: traceback.print_exc()
 
   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
@@ -347,7 +347,7 @@ class StandardNode(Node):
     next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
     self.topology = next_topology
     if self.topology_viz:
-      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
+      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id)
     return next_topology
 
   @property

+ 61 - 1
exo/viz/test_topology_viz.py

@@ -1,9 +1,57 @@
 import asyncio
 import unittest
+from datetime import timedelta
 from exo.viz.topology_viz import TopologyViz
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.partitioning_strategy import Partition
+from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
+
+
+def create_hf_repo_progress_event(
+    completed_files: int = 5,
+    total_files: int = 10,
+    downloaded_bytes: int = 500000000,
+    downloaded_bytes_this_session: int = 250000000,
+    total_bytes: int = 1000000000,
+    overall_speed: int = 5000000,
+    overall_eta: timedelta = timedelta(seconds=100),
+    file_progress: dict = None,
+    status: str = "in_progress"
+) -> RepoProgressEvent:
+    if file_progress is None:
+        file_progress = {
+            "file1.bin": RepoFileProgressEvent(
+                file_path="file1.bin",
+                downloaded=100000000,
+                downloaded_this_session=50000000,
+                total=200000000,
+                speed=1000000,
+                eta=timedelta(seconds=100),
+                status="in_progress"
+            ),
+            "file2.bin": RepoFileProgressEvent(
+                file_path="file2.bin",
+                downloaded=200000000,
+                downloaded_this_session=100000000,
+                total=200000000,
+                speed=2000000,
+                eta=timedelta(seconds=0),
+                status="complete"
+            )
+        }
+
+    return RepoProgressEvent(
+        completed_files=completed_files,
+        total_files=total_files,
+        downloaded_bytes=downloaded_bytes,
+        downloaded_bytes_this_session=downloaded_bytes_this_session,
+        total_bytes=total_bytes,
+        overall_speed=overall_speed,
+        overall_eta=overall_eta,
+        file_progress=file_progress,
+        status=status
+    )
 
 
 class TestNodeViz(unittest.IsolatedAsyncioTestCase):
@@ -30,7 +78,7 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
     await asyncio.sleep(2)  # Simulate running for a short time
 
   async def test_layout_generation(self):
-    self.top_viz._generate_layout()
+    self.top_viz._generate_main_layout()
     self.top_viz.refresh()
     import time
 
@@ -43,6 +91,13 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
         Partition("node2", 0.4, 0.8),
         Partition("node3", 0.8, 0.9),
       ],
+      "node1",
+      {
+        "node1": create_hf_repo_progress_event(),
+        "node2": create_hf_repo_progress_event(),
+        "node3": create_hf_repo_progress_event(),
+        "node4": create_hf_repo_progress_event(),
+      },
     )
     time.sleep(2)
     self.topology.active_node_id = "node3"
@@ -54,6 +109,11 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
         Partition("node2", 0.5, 0.7),
         Partition("node4", 0.7, 0.9),
       ],
+      "node5",
+      {
+        "node1": create_hf_repo_progress_event(),
+        "node5": create_hf_repo_progress_event(),
+      },
     )
     time.sleep(2)
 

+ 88 - 76
exo/viz/topology_viz.py

@@ -1,17 +1,17 @@
 import math
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Dict
 from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
+from exo.download.hf.hf_helpers import RepoProgressEvent
 from rich.console import Console
 from rich.panel import Panel
 from rich.text import Text
 from rich.live import Live
 from rich.style import Style
-from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
 from rich.table import Table
+from rich.layout import Layout
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
-from exo.inference.hf_helpers import HFRepoProgressEvent
 
 class TopologyViz:
   def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
@@ -19,60 +19,54 @@ class TopologyViz:
     self.web_chat_url = web_chat_url
     self.topology = Topology()
     self.partitions: List[Partition] = []
-    self.download_progress = None
+    self.node_id = None
+    self.node_download_progress: Dict[str, RepoProgressEvent] = {}
 
     self.console = Console()
-    self.panel = Panel(self._generate_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
-    self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
+    self.layout = Layout()
+    self.layout.split(
+      Layout(name="main"),
+      Layout(name="download", size=15)
+    )
+    self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.download_panel = Panel("", title="Download Progress", border_style="cyan")
+    self.layout["main"].update(self.main_panel)
+    self.layout["download"].update(self.download_panel)
+    self.live_panel = Live(self.layout, auto_refresh=False, console=self.console)
     self.live_panel.start()
 
-  def update_visualization(self, topology: Topology, partitions: List[Partition], download_progress: HFRepoProgressEvent = None):
+  def update_visualization(self, topology: Topology, partitions: List[Partition], node_id: Optional[str] = None, node_download_progress: Dict[str, RepoProgressEvent] = {}):
     self.topology = topology
     self.partitions = partitions
-    self.download_progress = download_progress
+    self.node_id = node_id
+    if node_download_progress:
+      self.node_download_progress = node_download_progress
     self.refresh()
 
   def refresh(self):
-    self.panel.renderable = self._generate_layout()
+    self.main_panel.renderable = self._generate_main_layout()
     # Update the panel title with the number of nodes and partitions
     node_count = len(self.topology.nodes)
-    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
-    self.live_panel.update(self.panel, refresh=True)
+    self.main_panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
 
-  def _generate_layout(self) -> str:
+    # Only show download_panel if there are in-progress downloads
+    if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
+      self.download_panel.renderable = self._generate_download_layout()
+      self.layout["download"].visible = True
+    else:
+      self.layout["download"].visible = False
+
+    self.live_panel.update(self.layout, refresh=True)
+
+  def _generate_main_layout(self) -> str:
     # Calculate visualization parameters
     num_partitions = len(self.partitions)
-    radius_x = 30  # Increased horizontal radius
-    radius_y = 12  # Decreased vertical radius
-    center_x, center_y = 50, 28  # Centered horizontally and moved up slightly
+    radius_x = 30
+    radius_y = 12
+    center_x, center_y = 50, 24  # Increased center_y to add more space
 
     # Generate visualization
-    visualization = [[" " for _ in range(100)] for _ in range(55)]  # Decreased height
-
-    # Draw download first so everything else is drawn on top
-    # If a download is in progress, show the download info summary
-    if self.download_progress and self.download_progress.status != "complete":
-        download_summary = _generate_download_summary(self.download_progress)
-        download_panel = Panel(
-            download_summary,
-            title="Download Progress",
-            border_style="cyan",
-            expand=False,
-            width=96,  # Further reduced to ensure it fits within the visualization
-            height=None  # Allow the panel to adjust its height based on content
-        )
-        console = Console(width=98, height=55)  # Reduced console width
-        with console.capture() as capture:
-            console.print(download_panel)
-        download_lines = capture.get().split('\n')
-        download_start_y = 15
-        panel_width = len(max(download_lines, key=len))
-        start_x = max(1, (100 - panel_width) // 2)  # Ensure start_x is at least 1 to avoid left border cut-off
-        for i, line in enumerate(download_lines):
-            for j, char in enumerate(line):
-                if 1 <= start_x + j < 99 and download_start_y + i < 55:  # Ensure we don't write to the rightmost column
-                    visualization[download_start_y + i][start_x + j] = char
-
+    visualization = [[" " for _ in range(100)] for _ in range(48)]  # Increased height to 48
 
     # Add exo_text at the top in bright yellow
     exo_lines = exo_text.split("\n")
@@ -80,7 +74,7 @@ class TopologyViz:
     max_line_length = max(len(line) for line in exo_lines)
     for i, line in enumerate(exo_lines):
       centered_line = line.center(max_line_length)
-      start_x = (100 - max_line_length) // 2 + 15  # Center the text plus empirical adjustment of 15
+      start_x = (100 - max_line_length) // 2 + 15
       colored_line = Text(centered_line, style=yellow_style)
       for j, char in enumerate(str(colored_line)):
         if 0 <= start_x + j < 100 and i < len(visualization):
@@ -95,9 +89,9 @@ class TopologyViz:
 
     info_start_y = len(exo_lines) + 1
     for i, line in enumerate(info_lines):
-      start_x = (100 - len(line)) // 2 + 15  # Center the info lines plus empirical adjustment of 15
+      start_x = (100 - len(line)) // 2 + 15
       for j, char in enumerate(line):
-        if 0 <= start_x + j < 100 and info_start_y + i < 55:
+        if 0 <= start_x + j < 100 and info_start_y + i < 48:
           visualization[info_start_y + i][start_x + j] = char
 
     # Calculate total FLOPS and position on the bar
@@ -105,13 +99,13 @@ class TopologyViz:
     bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
 
     # Add GPU poor/rich bar
-    bar_width = 30  # Increased bar width
-    bar_start_x = (100 - bar_width) // 2  # Center the bar
-    bar_y = info_start_y + len(info_lines) + 1  # Position the bar below the info section with two cells of space
+    bar_width = 30
+    bar_start_x = (100 - bar_width) // 2
+    bar_y = info_start_y + len(info_lines) + 1
 
     # Create a gradient bar using emojis
     gradient_bar = Text()
-    emojis = ["🟥", "🟧", "🟨", "🟩"]  # Red, Orange, Yellow, Green
+    emojis = ["🟥", "🟧", "🟨", "🟩"]
     for i in range(bar_width):
       emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
       gradient_bar.append(emojis[emoji_index])
@@ -133,6 +127,9 @@ class TopologyViz:
     visualization[bar_y + 1][pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
     visualization[bar_y + 2][pos_x] = "▲"
 
+    # Add an extra empty line for spacing
+    bar_y += 4
+
     for i, partition in enumerate(self.partitions):
       device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
 
@@ -140,11 +137,13 @@ class TopologyViz:
       x = int(center_x + radius_x * math.cos(angle))
       y = int(center_y + radius_y * math.sin(angle))
 
-      # Place node with different color for active node
+      # Place node with different color for active node and this node
       if partition.node_id == self.topology.active_node_id:
-        visualization[y][x] = "🔴"  # Red circle for active node
+        visualization[y][x] = "🔴"
+      elif partition.node_id == self.node_id:
+        visualization[y][x] = "🟢"
       else:
-        visualization[y][x] = "🔵"  # Blue circle for inactive nodes
+        visualization[y][x] = "🔵"
 
       # Place node info (model, memory, TFLOPS, partition) on three lines
       node_info = [
@@ -154,28 +153,27 @@ class TopologyViz:
       ]
 
       # Calculate info position based on angle
-      info_distance_x = radius_x + 6  # Increased horizontal distance
-      info_distance_y = radius_y + 3  # Decreased vertical distance
+      info_distance_x = radius_x + 6
+      info_distance_y = radius_y + 3
       info_x = int(center_x + info_distance_x * math.cos(angle))
       info_y = int(center_y + info_distance_y * math.sin(angle))
 
       # Adjust text position to avoid overwriting the node icon and prevent cutoff
-      if info_x < x:  # Text is to the left of the node
+      if info_x < x:
         info_x = max(0, x - len(max(node_info, key=len)) - 1)
-      elif info_x > x:  # Text is to the right of the node
+      elif info_x > x:
         info_x = min(99 - len(max(node_info, key=len)), info_x)
 
       # Adjust for top and bottom nodes
-      if 5 * math.pi / 4 < angle < 7 * math.pi / 4:  # Node is near the top
-        info_x += 4  # Shift text slightly to the right
-      elif math.pi / 4 < angle < 3 * math.pi / 4:  # Node is near the bottom
-        info_x += 3  # Shift text slightly to the right
-        info_y -= 2  # Move text up by two cells
+      if 5 * math.pi / 4 < angle < 7 * math.pi / 4:
+        info_x += 4
+      elif math.pi / 4 < angle < 3 * math.pi / 4:
+        info_x += 3
+        info_y -= 2
 
       for j, line in enumerate(node_info):
         for k, char in enumerate(line):
-          if 0 <= info_y + j < 55 and 0 <= info_x + k < 100:  # Updated height check
-            # Ensure we're not overwriting the node icon
+          if 0 <= info_y + j < 48 and 0 <= info_x + k < 100:
             if info_y + j != y or info_x + k != x:
               visualization[info_y + j][info_x + k] = char
 
@@ -190,33 +188,47 @@ class TopologyViz:
       for step in range(1, steps):
         line_x = int(x + (next_x - x) * step / steps)
         line_y = int(y + (next_y - y) * step / steps)
-        if 0 <= line_y < 55 and 0 <= line_x < 100:  # Updated height check
+        if 0 <= line_y < 48 and 0 <= line_x < 100:
           visualization[line_y][line_x] = "-"
 
     # Convert to string
     return "\n".join("".join(str(char) for char in row) for row in visualization)
 
-def _generate_download_summary(download_progress) -> Table:
+  def _generate_download_layout(self) -> Table:
     summary = Table(show_header=False, box=None, padding=(0, 1))
     summary.add_column("Info", style="cyan", no_wrap=True)
     summary.add_column("Progress", style="cyan", no_wrap=True)
     summary.add_column("Percentage", style="cyan", no_wrap=True)
 
-    title = f"Downloading model ({download_progress.completed_files}/{download_progress.total_files}):"
-    summary.add_row(Text(title, style="bold"))
-    progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
-    summary.add_row(progress_info)
+    # Current node download progress
+    if self.node_id in self.node_download_progress:
+        download_progress = self.node_download_progress[self.node_id]
+        title = f"Downloading model ({download_progress.completed_files}/{download_progress.total_files}):"
+        summary.add_row(Text(title, style="bold"))
+        progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
+        summary.add_row(progress_info)
+
+        eta_info = f"ETA: {download_progress.overall_eta}"
+        summary.add_row(eta_info)
+
+        summary.add_row("")  # Empty row for spacing
 
-    eta_info = f"ETA: {download_progress.overall_eta}"
-    summary.add_row(eta_info)
+        for file_path, file_progress in download_progress.file_progress.items():
+            if file_progress.status != "complete":
+                progress = int(file_progress.downloaded / file_progress.total * 20)
+                bar = f"[{'=' * progress}{' ' * (20 - progress)}]"
+                percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
+                summary.add_row(Text(file_path[:20], style="cyan"), bar, percentage)
 
     summary.add_row("")  # Empty row for spacing
 
-    for file_path, file_progress in download_progress.file_progress.items():
-      if file_progress.status != "complete":
-        progress = int(file_progress.downloaded / file_progress.total * 20)  # Increased bar width
-        bar = f"[{'=' * progress}{' ' * (20 - progress)}]"
-        percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
-        summary.add_row(Text(file_path[:20], style="cyan"), bar, percentage)  # Increased file path length
+    # Other nodes download progress summary
+    summary.add_row(Text("Other Nodes Download Progress:", style="bold"))
+    for node_id, progress in self.node_download_progress.items():
+        if node_id != self.node_id:
+            truncated_id = node_id[:8] + "..." if len(node_id) > 8 else node_id
+            percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
+            speed = pretty_print_bytes_per_second(progress.overall_speed)
+            summary.add_row(f"{truncated_id}: {percentage:.1f}% ({speed})")
 
     return summary

+ 2 - 2
extra/download_hf.py

@@ -1,6 +1,6 @@
 import argparse
 import asyncio
-from exo.inference.hf_helpers import download_all_files, HFRepoProgressEvent, HFRepoFileProgressEvent
+from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent
 
 DEFAULT_ALLOW_PATTERNS = [
     "*.json",
@@ -23,7 +23,7 @@ DEFAULT_IGNORE_PATTERNS = [
 ]
 
 async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
-    async def progress_callback(event: HFRepoProgressEvent):
+    async def progress_callback(event: RepoProgressEvent):
         print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
         print(f"Estimated time remaining: {event.overall_eta}")
         print("File Progress:")

+ 16 - 5
main.py

@@ -2,12 +2,14 @@ import argparse
 import asyncio
 import signal
 import json
-import uuid
+import time
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
+from exo.download.shard_download import ShardDownloader
+from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id
 
 # parse args
@@ -22,7 +24,7 @@ parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
-parser.add_argument("--max-generate-tokens", type=int, default=256, help="Max tokens to generate in each request")
+parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 args = parser.parse_args()
@@ -32,9 +34,10 @@ print_yellow_exo()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
+shard_downloader: ShardDownloader = HFShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
-inference_engine = get_inference_engine(inference_engine_name)
-print(f"Using inference engine: {inference_engine.__class__.__name__}")
+inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
+print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 
 if args.node_port is None:
     args.node_port = find_available_port(args.node_host)
@@ -60,7 +63,15 @@ node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)
-inference_engine.set_progress_callback(lambda event: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()}))))
+
+last_broadcast_time = 0
+def throttled_broadcast(shard, event):
+    global last_broadcast_time
+    current_time = time.time()
+    if current_time - last_broadcast_time >= 0.1:
+        last_broadcast_time = current_time
+        asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""