Explorar el Código

Make a separate ShardDownloader abstract class with an HFShardDownloader implementation. This opens up plugging in different methods of downloading model shards, e.g. #79 / #16.

Alex Cheema hace 1 año
padre
commit
476a714bbb

+ 1 - 4
exo/api/chatgpt_api.py

@@ -103,10 +103,7 @@ async def resolve_tokenizer(model_id: str):
     if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
     if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())
     if DEBUG >= 4: print(traceback.format_exc())
 
 
-  if DEBUG >= 4: print(f"Trying mlx tokenizer for {model_id}")
-  from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
-
-  return load_tokenizer(await get_model_path(model_id))
+  raise ValueError(f"[TODO] Unsupported model: {model_id}")
 
 
 
 
 def generate_completion(
 def generate_completion(

+ 74 - 0
exo/download/download_progress.py

@@ -0,0 +1,74 @@
+from typing import Dict, Callable, Coroutine, Any, Literal
+from dataclasses import dataclass
+from datetime import timedelta
+
+@dataclass
+class RepoFileProgressEvent:
+    file_path: str
+    downloaded: int
+    downloaded_this_session: int
+    total: int
+    speed: int
+    eta: timedelta
+    status: Literal["not_started", "in_progress", "complete"]
+
+    def to_dict(self):
+        return {
+            "file_path": self.file_path,
+            "downloaded": self.downloaded,
+            "downloaded_this_session": self.downloaded_this_session,
+            "total": self.total,
+            "speed": self.speed,
+            "eta": self.eta.total_seconds(),
+            "status": self.status
+        }
+
+    @classmethod
+    def from_dict(cls, data):
+        # Convert eta from seconds back to timedelta
+        if 'eta' in data:
+            data['eta'] = timedelta(seconds=data['eta'])
+        return cls(**data)
+
+@dataclass
+class RepoProgressEvent:
+    completed_files: int
+    total_files: int
+    downloaded_bytes: int
+    downloaded_bytes_this_session: int
+    total_bytes: int
+    overall_speed: int
+    overall_eta: timedelta
+    file_progress: Dict[str, RepoFileProgressEvent]
+    status: Literal["not_started", "in_progress", "complete"]
+
+    def to_dict(self):
+        return {
+            "completed_files": self.completed_files,
+            "total_files": self.total_files,
+            "downloaded_bytes": self.downloaded_bytes,
+            "downloaded_bytes_this_session": self.downloaded_bytes_this_session,
+            "total_bytes": self.total_bytes,
+            "overall_speed": self.overall_speed,
+            "overall_eta": self.overall_eta.total_seconds(),
+            "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
+            "status": self.status
+        }
+
+    @classmethod
+    def from_dict(cls, data):
+        # Convert overall_eta from seconds back to timedelta
+        if 'overall_eta' in data:
+            data['overall_eta'] = timedelta(seconds=data['overall_eta'])
+
+        # Parse file_progress
+        if 'file_progress' in data:
+            data['file_progress'] = {
+                k: RepoFileProgressEvent.from_dict(v)
+                for k, v in data['file_progress'].items()
+            }
+
+        return cls(**data)
+
+RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
+RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]

+ 52 - 89
exo/inference/hf_helpers.py → exo/download/hf/hf_helpers.py

@@ -1,5 +1,6 @@
 import asyncio
 import asyncio
 import aiohttp
 import aiohttp
+import json
 import os
 import os
 from urllib.parse import urljoin
 from urllib.parse import urljoin
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
@@ -7,9 +8,9 @@ from datetime import datetime, timedelta
 from fnmatch import fnmatch
 from fnmatch import fnmatch
 from pathlib import Path
 from pathlib import Path
 from typing import Generator, Iterable, TypeVar, TypedDict
 from typing import Generator, Iterable, TypeVar, TypedDict
-from dataclasses import dataclass
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
 from exo.helpers import DEBUG
 from exo.helpers import DEBUG
+from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 
 
 T = TypeVar("T")
 T = TypeVar("T")
 def filter_repo_objects(
 def filter_repo_objects(
@@ -21,10 +22,8 @@ def filter_repo_objects(
 ) -> Generator[T, None, None]:
 ) -> Generator[T, None, None]:
     if isinstance(allow_patterns, str):
     if isinstance(allow_patterns, str):
         allow_patterns = [allow_patterns]
         allow_patterns = [allow_patterns]
-
     if isinstance(ignore_patterns, str):
     if isinstance(ignore_patterns, str):
         ignore_patterns = [ignore_patterns]
         ignore_patterns = [ignore_patterns]
-
     if allow_patterns is not None:
     if allow_patterns is not None:
         allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
         allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
     if ignore_patterns is not None:
     if ignore_patterns is not None:
@@ -37,18 +36,14 @@ def filter_repo_objects(
             if isinstance(item, Path):
             if isinstance(item, Path):
                 return str(item)
                 return str(item)
             raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
             raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
-
         key = _identity
         key = _identity
 
 
     for item in items:
     for item in items:
         path = key(item)
         path = key(item)
-
         if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
         if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
             continue
             continue
-
         if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
         if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
             continue
             continue
-
         yield item
         yield item
 
 
 def _add_wildcard_to_directories(pattern: str) -> str:
 def _add_wildcard_to_directories(pattern: str) -> str:
@@ -99,84 +94,13 @@ async def fetch_file_list(session, repo_id, revision, path=""):
             raise Exception(f"Failed to fetch file list: {response.status}")
             raise Exception(f"Failed to fetch file list: {response.status}")
 
 
 
 
-@dataclass
-class HFRepoFileProgressEvent:
-    file_path: str
-    downloaded: int
-    downloaded_this_session: int
-    total: int
-    speed: int
-    eta: timedelta
-    status: Literal["not_started", "in_progress", "complete"]
-
-    def to_dict(self):
-        return {
-            "file_path": self.file_path,
-            "downloaded": self.downloaded,
-            "downloaded_this_session": self.downloaded_this_session,
-            "total": self.total,
-            "speed": self.speed,
-            "eta": self.eta.total_seconds(),
-            "status": self.status
-        }
-
-    @classmethod
-    def from_dict(cls, data):
-        # Convert eta from seconds back to timedelta
-        if 'eta' in data:
-            data['eta'] = timedelta(seconds=data['eta'])
-        return cls(**data)
-
-@dataclass
-class HFRepoProgressEvent:
-    completed_files: int
-    total_files: int
-    downloaded_bytes: int
-    downloaded_bytes_this_session: int
-    total_bytes: int
-    overall_speed: int
-    overall_eta: timedelta
-    file_progress: Dict[str, HFRepoFileProgressEvent]
-    status: Literal["not_started", "in_progress", "complete"]
-
-    def to_dict(self):
-        return {
-            "completed_files": self.completed_files,
-            "total_files": self.total_files,
-            "downloaded_bytes": self.downloaded_bytes,
-            "downloaded_bytes_this_session": self.downloaded_bytes_this_session,
-            "total_bytes": self.total_bytes,
-            "overall_speed": self.overall_speed,
-            "overall_eta": self.overall_eta.total_seconds(),
-            "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
-            "status": self.status
-        }
-
-    @classmethod
-    def from_dict(cls, data):
-        # Convert overall_eta from seconds back to timedelta
-        if 'overall_eta' in data:
-            data['overall_eta'] = timedelta(seconds=data['overall_eta'])
-
-        # Parse file_progress
-        if 'file_progress' in data:
-            data['file_progress'] = {
-                k: HFRepoFileProgressEvent.from_dict(v)
-                for k, v in data['file_progress'].items()
-            }
-
-        return cls(**data)
-
-HFRepoFileProgressCallback = Callable[[HFRepoFileProgressEvent], Coroutine[Any, Any, None]]
-HFRepoProgressCallback = Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]
-
 @retry(
 @retry(
     stop=stop_after_attempt(5),
     stop=stop_after_attempt(5),
     wait=wait_exponential(multiplier=1, min=4, max=60),
     wait=wait_exponential(multiplier=1, min=4, max=60),
     retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
     retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
     reraise=True
     reraise=True
 )
 )
-async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[HFRepoFileProgressCallback] = None, use_range_request: bool = True):
+async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True):
     base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
     base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
     url = urljoin(base_url, file_path)
     url = urljoin(base_url, file_path)
     local_path = os.path.join(save_directory, file_path)
     local_path = os.path.join(save_directory, file_path)
@@ -198,7 +122,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
         if downloaded_size == total_size:
         if downloaded_size == total_size:
             if DEBUG >= 2: print(f"File already downloaded: {file_path}")
             if DEBUG >= 2: print(f"File already downloaded: {file_path}")
             if progress_callback:
             if progress_callback:
-                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+                await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
             return
             return
 
 
         if response.status == 200:
         if response.status == 200:
@@ -221,7 +145,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
                 if downloaded_size == total_size:
                 if downloaded_size == total_size:
                     if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
                     if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
                     if progress_callback:
                     if progress_callback:
-                        await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+                        await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
                     return
                     return
             except ValueError:
             except ValueError:
                 if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
                 if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
@@ -232,7 +156,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
         if downloaded_size == total_size:
         if downloaded_size == total_size:
             print(f"File already downloaded: {file_path}")
             print(f"File already downloaded: {file_path}")
             if progress_callback:
             if progress_callback:
-                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+                await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
             return
             return
 
 
         DOWNLOAD_CHUNK_SIZE = 32768
         DOWNLOAD_CHUNK_SIZE = 32768
@@ -249,10 +173,10 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
                     eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
                     eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
                     status = "in_progress" if downloaded_size < total_size else "complete"
                     status = "in_progress" if downloaded_size < total_size else "complete"
                     if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
                     if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
-                    await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
+                    await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
         if DEBUG >= 2: print(f"Downloaded: {file_path}")
         if DEBUG >= 2: print(f"Downloaded: {file_path}")
 
 
-async def download_all_files(repo_id: str, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
+async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None) -> Path:
     repo_root = get_repo_root(repo_id)
     repo_root = get_repo_root(repo_id)
     refs_dir = repo_root / "refs"
     refs_dir = repo_root / "refs"
     snapshots_dir = repo_root / "snapshots"
     snapshots_dir = repo_root / "snapshots"
@@ -283,11 +207,11 @@ async def download_all_files(repo_id: str, revision: str = "main", progress_call
         filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
         filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
         total_files = len(filtered_file_list)
         total_files = len(filtered_file_list)
         total_bytes = sum(file["size"] for file in filtered_file_list)
         total_bytes = sum(file["size"] for file in filtered_file_list)
-        file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
+        file_progress: Dict[str, RepoFileProgressEvent] = {file["path"]: RepoFileProgressEvent(file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
         start_time = datetime.now()
         start_time = datetime.now()
 
 
         async def download_with_progress(file_info, progress_state):
         async def download_with_progress(file_info, progress_state):
-            async def file_progress_callback(event: HFRepoFileProgressEvent):
+            async def file_progress_callback(event: RepoFileProgressEvent):
                 progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
                 progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
                 progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
                 progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
                 file_progress[event.file_path] = event
                 file_progress[event.file_path] = event
@@ -297,21 +221,60 @@ async def download_all_files(repo_id: str, revision: str = "main", progress_call
                     remaining_bytes = total_bytes - progress_state['downloaded_bytes']
                     remaining_bytes = total_bytes - progress_state['downloaded_bytes']
                     overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
                     overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
                     status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
                     status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
-                    await progress_callback(HFRepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+                    await progress_callback(RepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
 
 
             await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
             await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
             progress_state['completed_files'] += 1
             progress_state['completed_files'] += 1
-            file_progress[file_info["path"]] = HFRepoFileProgressEvent(file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
+            file_progress[file_info["path"]] = RepoFileProgressEvent(file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
             if progress_callback:
             if progress_callback:
                 elapsed_time = (datetime.now() - start_time).total_seconds()
                 elapsed_time = (datetime.now() - start_time).total_seconds()
                 overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
                 overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
                 remaining_bytes = total_bytes - progress_state['downloaded_bytes']
                 remaining_bytes = total_bytes - progress_state['downloaded_bytes']
                 overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
                 overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
                 status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
                 status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-                await progress_callback(HFRepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+                await progress_callback(RepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
 
 
         progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
         progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
         tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
         tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
         await asyncio.gather(*tasks)
         await asyncio.gather(*tasks)
 
 
     return snapshot_dir
     return snapshot_dir
+
+async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
+    """
+    Retrieve the weight map from the model.safetensors.index.json file.
+
+    Args:
+        repo_id (str): The Hugging Face repository ID.
+        revision (str): The revision of the repository to use.
+
+    Returns:
+        Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
+    """
+
+    # Download the index file
+    await download_repo_files(
+        repo_id=repo_id,
+        revision=revision,
+        allow_patterns="model.safetensors.index.json"
+    )
+
+    # Check if the file exists
+    repo_root = get_repo_root(repo_id)
+    snapshot_dir = repo_root / "snapshots"
+    index_file = next(snapshot_dir.glob("*/model.safetensors.index.json"), None)
+
+    if index_file and index_file.exists():
+        with open(index_file, 'r') as f:
+            index_data = json.load(f)
+        return index_data.get("weight_map")
+
+    return None
+
+def extract_layer_num(tensor_name: str) -> Optional[int]:
+    # Heuristic: return the first all-digit dot-separated component of the tensor name; may need adjusting for other naming conventions
+    parts = tensor_name.split('.')
+    for part in parts:
+        if part.isdigit():
+            return int(part)
+    return None

+ 76 - 0
exo/download/hf/hf_shard_download.py

@@ -0,0 +1,76 @@
+import asyncio
+from pathlib import Path
+from typing import Dict, List, Tuple
+from exo.inference.shard import Shard
+from exo.download.shard_download import ShardDownloader
+from exo.download.download_progress import RepoProgressEvent
+from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_repo_root, get_weight_map, extract_layer_num
+from exo.helpers import AsyncCallbackSystem
+
+class HFShardDownloader(ShardDownloader):
+    def __init__(self):
+        self.active_downloads: List[Tuple[Shard, asyncio.Task]] = []
+        self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+
+    async def ensure_shard(self, shard: Shard) -> Path:
+        # Cancel any overlapping downloads
+        to_remove = []
+        for active_shard, task in self.active_downloads:
+            if shard.overlaps(active_shard):
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass  # This is expected when cancelling a task
+                to_remove.append((active_shard, task))
+
+        # Remove cancelled downloads from the list
+        for item in to_remove:
+            self.active_downloads.remove(item)
+
+        # Start new download
+        download_task = asyncio.create_task(self._download_shard(shard))
+        self.active_downloads.append((shard, download_task))
+
+        try:
+            return await download_task
+        finally:
+            # Ensure the task is removed even if an exception occurs
+            if (shard, download_task) in self.active_downloads:
+                self.active_downloads.remove((shard, download_task))
+
+    async def _download_shard(self, shard: Shard) -> Path:
+        async def wrapped_progress_callback(event: RepoProgressEvent):
+            self._on_progress.trigger_all(shard, event)
+
+        weight_map = await get_weight_map(shard.model_id)
+        allow_patterns = self._get_allow_patterns(weight_map, shard.start_layer, shard.end_layer)
+
+        return await download_repo_files(
+            repo_id=shard.model_id,
+            progress_callback=wrapped_progress_callback,
+            allow_patterns=allow_patterns
+        )
+
+    @staticmethod
+    def _get_allow_patterns(weight_map: Dict[str, str], start_layer: int, end_layer: int) -> List[str]:
+        default_patterns = [
+            "*.json",
+            "*.py",
+            "tokenizer.model",
+            "*.tiktoken",
+            "*.txt",
+        ]
+        shard_specific_patterns = []
+        if weight_map:
+            for tensor_name, filename in weight_map.items():
+                layer_num = extract_layer_num(tensor_name)
+                if layer_num is not None and start_layer <= layer_num <= end_layer:
+                    shard_specific_patterns.append(filename)
+        else:
+            shard_specific_patterns = ["*.safetensors"]
+        return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates
+
+    @property
+    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+        return self._on_progress

+ 25 - 0
exo/download/shard_download.py

@@ -0,0 +1,25 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple
+from pathlib import Path
+from exo.inference.shard import Shard
+from exo.download.download_progress import RepoProgressEvent
+from exo.helpers import AsyncCallbackSystem
+
+class ShardDownloader(ABC):
+    @abstractmethod
+    async def ensure_shard(self, shard: Shard) -> Path:
+        """
+        Ensures that the shard is downloaded.
+        Does not allow multiple overlapping downloads at once.
+        If you try to download a Shard which overlaps a Shard that is already being downloaded,
+        the in-progress download will be cancelled and a new download will start.
+
+        Args:
+            shard (Shard): The shard to download.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+        pass

+ 3 - 3
exo/helpers.py

@@ -31,17 +31,17 @@ def get_system_info():
   return "Non-Mac, non-Linux system"
   return "Non-Mac, non-Linux system"
 
 
 
 
-def get_inference_engine(inference_engine_name):
+def get_inference_engine(inference_engine_name, shard_downloader: 'ShardDownloader'):
   if inference_engine_name == "mlx":
   if inference_engine_name == "mlx":
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 
 
-    return MLXDynamicShardInferenceEngine()
+    return MLXDynamicShardInferenceEngine(shard_downloader)
   elif inference_engine_name == "tinygrad":
   elif inference_engine_name == "tinygrad":
     from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
     from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
     import tinygrad.helpers
     import tinygrad.helpers
     tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
     tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
 
 
-    return TinygradDynamicShardInferenceEngine()
+    return TinygradDynamicShardInferenceEngine(shard_downloader)
   else:
   else:
     raise ValueError(f"Inference engine {inference_engine_name} not supported")
     raise ValueError(f"Inference engine {inference_engine_name} not supported")
 
 

+ 1 - 6
exo/inference/inference_engine.py

@@ -1,9 +1,8 @@
 import numpy as np
 import numpy as np
 
 
-from typing import Tuple, Optional, Callable, Coroutine, Any
+from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from .shard import Shard
 from .shard import Shard
-from exo.inference.hf_helpers import HFRepoProgressEvent
 
 
 class InferenceEngine(ABC):
 class InferenceEngine(ABC):
   @abstractmethod
   @abstractmethod
@@ -13,7 +12,3 @@ class InferenceEngine(ABC):
   @abstractmethod
   @abstractmethod
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
     pass
-
-  @abstractmethod
-  def set_progress_callback(self, progress_callback: Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]):
-    pass

+ 6 - 8
exo/inference/mlx/sharded_inference_engine.py

@@ -4,14 +4,14 @@ from ..inference_engine import InferenceEngine
 from .sharded_model import StatefulShardedModel
 from .sharded_model import StatefulShardedModel
 from .sharded_utils import load_shard, get_image_from_str
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
 from ..shard import Shard
-from typing import Optional, Callable
-from exo.inference.hf_helpers import HFRepoProgressCallback
+from typing import Optional
+from exo.download.shard_download import ShardDownloader
 
 
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
 class MLXDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self, progress_callback: Optional[HFRepoProgressCallback] = None):
+  def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard = None
-    self.progress_callback = progress_callback
+    self.shard_downloader = shard_downloader
 
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
@@ -34,9 +34,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
     if self.shard == shard:
       return
       return
 
 
-    model_shard, self.tokenizer = await load_shard(shard.model_id, shard, progress_callback=self.progress_callback)
+    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_shard, self.tokenizer = await load_shard(model_path, shard)
     self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
     self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
     self.shard = shard
     self.shard = shard
-
-  def set_progress_callback(self, progress_callback: HFRepoProgressCallback):
-    self.progress_callback = progress_callback

+ 1 - 175
exo/inference/mlx/sharded_utils.py

@@ -15,18 +15,12 @@ import base64
 
 
 import mlx.core as mx
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.nn as nn
-from huggingface_hub import snapshot_download, list_repo_tree, get_paths_info
-from huggingface_hub.utils import filter_repo_objects
-from huggingface_hub.file_download import repo_folder_name
-from huggingface_hub.constants import HF_HUB_CACHE
-from huggingface_hub.utils._errors import RepositoryNotFoundError
 from transformers import AutoProcessor
 from transformers import AutoProcessor
 
 
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tuner.utils import apply_lora_layers
 from mlx_lm.tuner.utils import apply_lora_layers
 
 
 from exo import DEBUG
 from exo import DEBUG
-from exo.inference.hf_helpers import download_all_files, HFRepoProgressCallback
 from ..shard import Shard
 from ..shard import Shard
 
 
 
 
@@ -162,182 +156,14 @@ def load_model_shard(
   model.eval()
   model.eval()
   return model
   return model
 
 
-
-repo_id_safetensors_layers = {
-  "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit": {
-    "model.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
-  },
-  "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit": {
-    "model-00001-of-00008.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
-    "model-00002-of-00008.safetensors": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
-    "model-00003-of-00008.safetensors": [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
-    "model-00004-of-00008.safetensors": [31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42],
-    "model-00005-of-00008.safetensors": [42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53],
-    "model-00006-of-00008.safetensors": [53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64],
-    "model-00007-of-00008.safetensors": [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75],
-    "model-00008-of-00008.safetensors": [75, 76, 77, 78, 79],
-  },
-  "mlx-community/Meta-Llama-3.1-405B-Instruct-4bit": {
-    "model-00001-of-00046.safetensors": [0, 1, 2],
-    "model-00002-of-00046.safetensors": [2, 3, 4, 5],
-    "model-00003-of-00046.safetensors": [5, 6, 7],
-    "model-00004-of-00046.safetensors": [8, 9, 10],
-    "model-00005-of-00046.safetensors": [10, 11, 12, 13],
-    "model-00006-of-00046.safetensors": [13, 14, 15, 16],
-    "model-00007-of-00046.safetensors": [16, 17, 18, 19],
-    "model-00008-of-00046.safetensors": [19, 20, 21],
-    "model-00009-of-00046.safetensors": [22, 23, 24],
-    "model-00010-of-00046.safetensors": [24, 25, 26, 27],
-    "model-00011-of-00046.safetensors": [27, 28, 29, 30],
-    "model-00012-of-00046.safetensors": [30, 31, 32, 33],
-    "model-00013-of-00046.safetensors": [33, 34, 35],
-    "model-00014-of-00046.safetensors": [36, 37, 38],
-    "model-00015-of-00046.safetensors": [38, 39, 40, 41],
-    "model-00016-of-00046.safetensors": [41, 42, 43, 44],
-    "model-00017-of-00046.safetensors": [44, 45, 46, 47],
-    "model-00018-of-00046.safetensors": [47, 48, 49],
-    "model-00019-of-00046.safetensors": [50, 51, 52],
-    "model-00020-of-00046.safetensors": [52, 53, 54, 55],
-    "model-00021-of-00046.safetensors": [55, 56, 57, 58],
-    "model-00022-of-00046.safetensors": [58, 59, 60, 61],
-    "model-00023-of-00046.safetensors": [61, 62, 63],
-    "model-00024-of-00046.safetensors": [64, 65, 66],
-    "model-00025-of-00046.safetensors": [66, 67, 68, 69],
-    "model-00026-of-00046.safetensors": [69, 70, 71, 72],
-    "model-00027-of-00046.safetensors": [72, 73, 74, 75],
-    "model-00028-of-00046.safetensors": [75, 76, 77],
-    "model-00029-of-00046.safetensors": [78, 79, 80],
-    "model-00030-of-00046.safetensors": [80, 81, 82, 83],
-    "model-00031-of-00046.safetensors": [83, 84, 85, 86],
-    "model-00032-of-00046.safetensors": [86, 87, 88, 89],
-    "model-00033-of-00046.safetensors": [89, 90, 91],
-    "model-00034-of-00046.safetensors": [92, 93, 94],
-    "model-00035-of-00046.safetensors": [94, 95, 96, 97],
-    "model-00036-of-00046.safetensors": [97, 98, 99, 100],
-    "model-00037-of-00046.safetensors": [100, 101, 102, 103],
-    "model-00038-of-00046.safetensors": [103, 104, 105],
-    "model-00039-of-00046.safetensors": [106, 107, 108],
-    "model-00040-of-00046.safetensors": [108, 109, 110, 111],
-    "model-00041-of-00046.safetensors": [111, 112, 113, 114],
-    "model-00042-of-00046.safetensors": [114, 115, 116, 117],
-    "model-00043-of-00046.safetensors": [117, 118, 119],
-    "model-00044-of-00046.safetensors": [120, 121, 122],
-    "model-00045-of-00046.safetensors": [122, 123, 124, 125],
-    "model-00046-of-00046.safetensors": [125]
-  },
-  "mlx-community/Mistral-Nemo-Instruct-2407-4bit": {
-    "model-00001-of-00002.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32],
-    "model-00002-of-00002.safetensors": [32, 33, 34, 35, 36, 37, 38, 39],
-  },
-  "mlx-community/Mistral-Large-Instruct-2407-4bit": {
-    "model-00001-of-00014.safetensors": [0, 1, 2, 3, 4, 5, 6],
-    "model-00002-of-00014.safetensors": [6, 7, 8, 9, 10, 11, 12, 13],
-    "model-00003-of-00014.safetensors": [13, 14, 15, 16, 17, 18, 19, 20],
-    "model-00004-of-00014.safetensors": [20, 21, 22, 23, 24, 25, 26],
-    "model-00005-of-00014.safetensors": [27, 28, 29, 30, 31, 32, 33],
-    "model-00006-of-00014.safetensors": [33, 34, 35, 36, 37, 38, 39, 40],
-    "model-00007-of-00014.safetensors": [40, 41, 42, 43, 44, 45, 46, 47],
-    "model-00008-of-00014.safetensors": [47, 48, 49, 50, 51, 52, 53, 54],
-    "model-00009-of-00014.safetensors": [54, 55, 56, 57, 58, 59, 60],
-    "model-00010-of-00014.safetensors": [61, 62, 63, 64, 65, 66, 67],
-    "model-00011-of-00014.safetensors": [67, 68, 69, 70, 71, 72, 73, 74],
-    "model-00012-of-00014.safetensors": [74, 75, 76, 77, 78, 79, 80, 81],
-    "model-00013-of-00014.safetensors": [81, 82, 83, 84, 85, 86, 87],
-    "model-00014-of-00014.safetensors": [87]
-  },
-  "llava-hf/llava-1.5-7b-hf": {
-    "model-00001-of-00003.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
-    "model-00002-of-00003.safetensors": [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22],
-    "model-00003-of-00003.safetensors": [22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
-  }
-}
-
-def get_safetensors_allow_patterns(repo_id: str, shard: Optional[Shard] = None):
-    return ["*.safetensors"] # TODO: enable this
-    if not shard:
-      return ["*.safetensors"]
-
-    allow_patterns = []
-    for repo_id, safetensors_layers in repo_id_safetensors_layers.items():
-        if repo_id == shard.model_id:
-            for safetensor, layers in safetensors_layers.items():
-                if any(shard.start_layer <= layer <= shard.end_layer for layer in layers):
-                    allow_patterns.append(safetensor)
-
-    return allow_patterns if len(allow_patterns) > 0 else ["*.safetensors"]
-
-async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None) -> Path:
-  """
-  Ensures the model is available locally. If the path does not exist locally,
-  it is downloaded from the Hugging Face Hub.
-
-  Args:
-   path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
-   revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
-
-  Returns:
-   Path: The path to the model.
-  """
-  model_path = Path(path_or_hf_repo)
-  if not model_path.exists():
-    try:
-      model_path = Path(
-        await download_all_files(
-          repo_id=path_or_hf_repo,
-          revision=revision,
-          allow_patterns=[
-            "*.json",
-            "*.py",
-            "tokenizer.model",
-            "*.tiktoken",
-            "*.txt",
-          ] + get_safetensors_allow_patterns(path_or_hf_repo, shard),
-          progress_callback=progress_callback,
-        )
-      )
-    except RepositoryNotFoundError:
-      raise ModelNotFoundError(
-        f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
-        "Please make sure you specified the local path or Hugging Face"
-        " repo id correctly.\nIf you are trying to access a private or"
-        " gated Hugging Face repo, make sure you are authenticated:\n"
-        "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
-      ) from None
-  return model_path
-
-
 async def load_shard(
 async def load_shard(
-  path_or_hf_repo: str,
+  model_path: str,
   shard: Shard,
   shard: Shard,
   tokenizer_config={},
   tokenizer_config={},
   model_config={},
   model_config={},
   adapter_path: Optional[str] = None,
   adapter_path: Optional[str] = None,
   lazy: bool = False,
   lazy: bool = False,
-  progress_callback: Optional[HFRepoProgressCallback] = None,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
 ) -> Tuple[nn.Module, TokenizerWrapper]:
-  """
-  Load the model and tokenizer from a given path or a huggingface repository.
-
-  Args:
-   path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
-   tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
-    Defaults to an empty dictionary.
-   model_config(dict, optional): Configuration parameters specifically for the model.
-    Defaults to an empty dictionary.
-   adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
-    to the model. Default: ``None``.
-   lazy (bool): If False eval the model parameters to make sure they are
-    loaded in memory before returning, otherwise they will be loaded
-    when needed. Default: ``False``
-  Returns:
-   Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
-
-  Raises:
-   FileNotFoundError: If config file or safetensors are not found.
-   ValueError: If model class or args class are not found.
-  """
-  model_path = await get_model_path(path_or_hf_repo, shard, progress_callback=progress_callback)
-
   model = load_model_shard(model_path, shard, lazy, model_config)
   model = load_model_shard(model_path, shard, lazy, model_config)
   if adapter_path is not None:
   if adapter_path is not None:
     model = apply_lora_layers(model, adapter_path)
     model = apply_lora_layers(model, adapter_path)

+ 9 - 0
exo/inference/shard.py

@@ -24,3 +24,12 @@ class Shard:
       "end_layer": self.end_layer,
       "end_layer": self.end_layer,
       "n_layers": self.n_layers,
       "n_layers": self.n_layers,
     }
     }
+
+  def overlaps(self, other: 'Shard') -> bool:
+    return shards_overlap(self, other)
+
+def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
+  return (
+      shard1.model_id == shard2.model_id
+      and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer)
+  )

+ 5 - 8
exo/inference/tinygrad/inference.py

@@ -1,5 +1,5 @@
 from pathlib import Path
 from pathlib import Path
-from typing import List, Optional, Union, Callable, Coroutine, Any
+from typing import List, Optional
 import json
 import json
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
@@ -8,7 +8,7 @@ from tinygrad.helpers import tqdm
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 import numpy as np
 import numpy as np
-from exo.inference.hf_helpers import HFRepoProgressCallback, HFRepoProgressEvent, download_all_files
+from exo.download.shard_download import ShardDownloader
 
 
 MODEL_PARAMS = {
 MODEL_PARAMS = {
   "8B": {
   "8B": {
@@ -147,9 +147,9 @@ def prefill(model, toks, start_pos=0):
 
 
 
 
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self, progress_callback: Optional[HFRepoProgressCallback] = None):
+  def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard = None
-    self.progress_callback = progress_callback
+    self.shard_downloader = shard_downloader
 
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
     # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
@@ -188,7 +188,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
     if self.shard == shard:
       return
       return
 
 
-    model_path = await download_all_files(shard.model_id, progress_callback=self.progress_callback)
+    model_path = await self.shard_downloader.ensure_shard(shard)
     print(f"{model_path=}")
     print(f"{model_path=}")
     model = build_transformer(model_path, shard=shard, model_size="8B" if "8b" in shard.model_id else "70B" if "70b" in shard.model_id else "8B")
     model = build_transformer(model_path, shard=shard, model_size="8B" if "8b" in shard.model_id else "70B" if "70b" in shard.model_id else "8B")
     from transformers import AutoTokenizer
     from transformers import AutoTokenizer
@@ -197,6 +197,3 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard = shard
     self.shard = shard
     self.model = model
     self.model = model
     self.tokenizer = tokenizer
     self.tokenizer = tokenizer
-
-  def set_progress_callback(self, progress_callback: Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]):
-    self.progress_callback = progress_callback

+ 8 - 8
exo/orchestration/standard_node.py

@@ -14,7 +14,7 @@ from exo.topology.partitioning_strategy import Partition, PartitioningStrategy,
 from exo import DEBUG
 from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
 from exo.viz.topology_viz import TopologyViz
-from exo.inference.hf_helpers import HFRepoProgressEvent
+from exo.download.hf.hf_helpers import RepoProgressEvent
 
 
 
 
 class StandardNode(Node):
 class StandardNode(Node):
@@ -25,7 +25,7 @@ class StandardNode(Node):
     inference_engine: InferenceEngine,
     inference_engine: InferenceEngine,
     discovery: Discovery,
     discovery: Discovery,
     partitioning_strategy: PartitioningStrategy = None,
     partitioning_strategy: PartitioningStrategy = None,
-    max_generate_tokens: int = 256,
+    max_generate_tokens: int = 1024,
     chatgpt_api_endpoint: Optional[str] = None,
     chatgpt_api_endpoint: Optional[str] = None,
     web_chat_url: Optional[str] = None,
     web_chat_url: Optional[str] = None,
     disable_tui: Optional[bool] = False,
     disable_tui: Optional[bool] = False,
@@ -44,6 +44,7 @@ class StandardNode(Node):
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
+    self.node_download_progress: Dict[str, RepoProgressEvent] = {}
 
 
   def on_node_status(self, request_id, opaque_status):
   def on_node_status(self, request_id, opaque_status):
     try:
     try:
@@ -57,14 +58,13 @@ class StandardNode(Node):
       download_progress = None
       download_progress = None
       if status_data.get("type", "") == "download_progress":
       if status_data.get("type", "") == "download_progress":
         if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
         if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
-        if status_data.get("node_id") == self.id:
-          download_progress = HFRepoProgressEvent.from_dict(status_data.get('progress'))
+        download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
+        self.node_download_progress[status_data.get('node_id')] = download_progress
       if self.topology_viz:
       if self.topology_viz:
-        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), download_progress)
+        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id, self.node_download_progress)
     except Exception as e:
     except Exception as e:
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
-      traceback.print_exc()
-      pass
+      if DEBUG >= 1: traceback.print_exc()
 
 
   async def start(self, wait_for_peers: int = 0) -> None:
   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
     await self.server.start()
@@ -347,7 +347,7 @@ class StandardNode(Node):
     next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
     next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
     self.topology = next_topology
     self.topology = next_topology
     if self.topology_viz:
     if self.topology_viz:
-      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
+      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id)
     return next_topology
     return next_topology
 
 
   @property
   @property

+ 61 - 1
exo/viz/test_topology_viz.py

@@ -1,9 +1,57 @@
 import asyncio
 import asyncio
 import unittest
 import unittest
+from datetime import timedelta
 from exo.viz.topology_viz import TopologyViz
 from exo.viz.topology_viz import TopologyViz
 from exo.topology.topology import Topology
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.partitioning_strategy import Partition
 from exo.topology.partitioning_strategy import Partition
+from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
+
+
+def create_hf_repo_progress_event(
+    completed_files: int = 5,
+    total_files: int = 10,
+    downloaded_bytes: int = 500000000,
+    downloaded_bytes_this_session: int = 250000000,
+    total_bytes: int = 1000000000,
+    overall_speed: int = 5000000,
+    overall_eta: timedelta = timedelta(seconds=100),
+    file_progress: dict = None,
+    status: str = "in_progress"
+) -> RepoProgressEvent:
+    if file_progress is None:
+        file_progress = {
+            "file1.bin": RepoFileProgressEvent(
+                file_path="file1.bin",
+                downloaded=100000000,
+                downloaded_this_session=50000000,
+                total=200000000,
+                speed=1000000,
+                eta=timedelta(seconds=100),
+                status="in_progress"
+            ),
+            "file2.bin": RepoFileProgressEvent(
+                file_path="file2.bin",
+                downloaded=200000000,
+                downloaded_this_session=100000000,
+                total=200000000,
+                speed=2000000,
+                eta=timedelta(seconds=0),
+                status="complete"
+            )
+        }
+
+    return RepoProgressEvent(
+        completed_files=completed_files,
+        total_files=total_files,
+        downloaded_bytes=downloaded_bytes,
+        downloaded_bytes_this_session=downloaded_bytes_this_session,
+        total_bytes=total_bytes,
+        overall_speed=overall_speed,
+        overall_eta=overall_eta,
+        file_progress=file_progress,
+        status=status
+    )
 
 
 
 
 class TestNodeViz(unittest.IsolatedAsyncioTestCase):
 class TestNodeViz(unittest.IsolatedAsyncioTestCase):
@@ -30,7 +78,7 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
     await asyncio.sleep(2)  # Simulate running for a short time
     await asyncio.sleep(2)  # Simulate running for a short time
 
 
   async def test_layout_generation(self):
   async def test_layout_generation(self):
-    self.top_viz._generate_layout()
+    # self.top_viz._generate_layout()
     self.top_viz.refresh()
     self.top_viz.refresh()
     import time
     import time
 
 
@@ -43,6 +91,13 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
         Partition("node2", 0.4, 0.8),
         Partition("node2", 0.4, 0.8),
         Partition("node3", 0.8, 0.9),
         Partition("node3", 0.8, 0.9),
       ],
       ],
+      "node1",
+      {
+        "node1": create_hf_repo_progress_event(),
+        "node2": create_hf_repo_progress_event(),
+        "node3": create_hf_repo_progress_event(),
+        "node4": create_hf_repo_progress_event(),
+      },
     )
     )
     time.sleep(2)
     time.sleep(2)
     self.topology.active_node_id = "node3"
     self.topology.active_node_id = "node3"
@@ -54,6 +109,11 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
         Partition("node2", 0.5, 0.7),
         Partition("node2", 0.5, 0.7),
         Partition("node4", 0.7, 0.9),
         Partition("node4", 0.7, 0.9),
       ],
       ],
+      "node5",
+      {
+        "node1": create_hf_repo_progress_event(),
+        "node5": create_hf_repo_progress_event(),
+      },
     )
     )
     time.sleep(2)
     time.sleep(2)
 
 

+ 88 - 76
exo/viz/topology_viz.py

@@ -1,17 +1,17 @@
 import math
 import math
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Dict
 from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.topology.topology import Topology
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
 from exo.topology.partitioning_strategy import Partition
+from exo.download.hf.hf_helpers import RepoProgressEvent
 from rich.console import Console
 from rich.console import Console
 from rich.panel import Panel
 from rich.panel import Panel
 from rich.text import Text
 from rich.text import Text
 from rich.live import Live
 from rich.live import Live
 from rich.style import Style
 from rich.style import Style
-from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
 from rich.table import Table
 from rich.table import Table
+from rich.layout import Layout
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
-from exo.inference.hf_helpers import HFRepoProgressEvent
 
 
 class TopologyViz:
 class TopologyViz:
   def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
   def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
@@ -19,60 +19,54 @@ class TopologyViz:
     self.web_chat_url = web_chat_url
     self.web_chat_url = web_chat_url
     self.topology = Topology()
     self.topology = Topology()
     self.partitions: List[Partition] = []
     self.partitions: List[Partition] = []
-    self.download_progress = None
+    self.node_id = None
+    self.node_download_progress: Dict[str, RepoProgressEvent] = {}
 
 
     self.console = Console()
     self.console = Console()
-    self.panel = Panel(self._generate_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
-    self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
+    self.layout = Layout()
+    self.layout.split(
+      Layout(name="main"),
+      Layout(name="download", size=15)
+    )
+    self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.download_panel = Panel("", title="Download Progress", border_style="cyan")
+    self.layout["main"].update(self.main_panel)
+    self.layout["download"].update(self.download_panel)
+    self.live_panel = Live(self.layout, auto_refresh=False, console=self.console)
     self.live_panel.start()
     self.live_panel.start()
 
 
-  def update_visualization(self, topology: Topology, partitions: List[Partition], download_progress: HFRepoProgressEvent = None):
+  def update_visualization(self, topology: Topology, partitions: List[Partition], node_id: Optional[str] = None, node_download_progress: Dict[str, RepoProgressEvent] = {}):
     self.topology = topology
     self.topology = topology
     self.partitions = partitions
     self.partitions = partitions
-    self.download_progress = download_progress
+    self.node_id = node_id
+    if node_download_progress:
+      self.node_download_progress = node_download_progress
     self.refresh()
     self.refresh()
 
 
   def refresh(self):
   def refresh(self):
-    self.panel.renderable = self._generate_layout()
+    self.main_panel.renderable = self._generate_main_layout()
     # Update the panel title with the number of nodes and partitions
     # Update the panel title with the number of nodes and partitions
     node_count = len(self.topology.nodes)
     node_count = len(self.topology.nodes)
-    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
-    self.live_panel.update(self.panel, refresh=True)
+    self.main_panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
 
 
-  def _generate_layout(self) -> str:
+    # Only show download_panel if there are in-progress downloads
+    if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
+      self.download_panel.renderable = self._generate_download_layout()
+      self.layout["download"].visible = True
+    else:
+      self.layout["download"].visible = False
+
+    self.live_panel.update(self.layout, refresh=True)
+
+  def _generate_main_layout(self) -> str:
     # Calculate visualization parameters
     # Calculate visualization parameters
     num_partitions = len(self.partitions)
     num_partitions = len(self.partitions)
-    radius_x = 30  # Increased horizontal radius
-    radius_y = 12  # Decreased vertical radius
-    center_x, center_y = 50, 28  # Centered horizontally and moved up slightly
+    radius_x = 30
+    radius_y = 12
+    center_x, center_y = 50, 24  # Increased center_y to add more space
 
 
     # Generate visualization
     # Generate visualization
-    visualization = [[" " for _ in range(100)] for _ in range(55)]  # Decreased height
-
-    # Draw download first so everything else is drawn on top
-    # If a download is in progress, show the download info summary
-    if self.download_progress and self.download_progress.status != "complete":
-        download_summary = _generate_download_summary(self.download_progress)
-        download_panel = Panel(
-            download_summary,
-            title="Download Progress",
-            border_style="cyan",
-            expand=False,
-            width=96,  # Further reduced to ensure it fits within the visualization
-            height=None  # Allow the panel to adjust its height based on content
-        )
-        console = Console(width=98, height=55)  # Reduced console width
-        with console.capture() as capture:
-            console.print(download_panel)
-        download_lines = capture.get().split('\n')
-        download_start_y = 15
-        panel_width = len(max(download_lines, key=len))
-        start_x = max(1, (100 - panel_width) // 2)  # Ensure start_x is at least 1 to avoid left border cut-off
-        for i, line in enumerate(download_lines):
-            for j, char in enumerate(line):
-                if 1 <= start_x + j < 99 and download_start_y + i < 55:  # Ensure we don't write to the rightmost column
-                    visualization[download_start_y + i][start_x + j] = char
-
+    visualization = [[" " for _ in range(100)] for _ in range(48)]  # Increased height to 48
 
 
     # Add exo_text at the top in bright yellow
     # Add exo_text at the top in bright yellow
     exo_lines = exo_text.split("\n")
     exo_lines = exo_text.split("\n")
@@ -80,7 +74,7 @@ class TopologyViz:
     max_line_length = max(len(line) for line in exo_lines)
     max_line_length = max(len(line) for line in exo_lines)
     for i, line in enumerate(exo_lines):
     for i, line in enumerate(exo_lines):
       centered_line = line.center(max_line_length)
       centered_line = line.center(max_line_length)
-      start_x = (100 - max_line_length) // 2 + 15  # Center the text plus empirical adjustment of 15
+      start_x = (100 - max_line_length) // 2 + 15
       colored_line = Text(centered_line, style=yellow_style)
       colored_line = Text(centered_line, style=yellow_style)
       for j, char in enumerate(str(colored_line)):
       for j, char in enumerate(str(colored_line)):
         if 0 <= start_x + j < 100 and i < len(visualization):
         if 0 <= start_x + j < 100 and i < len(visualization):
@@ -95,9 +89,9 @@ class TopologyViz:
 
 
     info_start_y = len(exo_lines) + 1
     info_start_y = len(exo_lines) + 1
     for i, line in enumerate(info_lines):
     for i, line in enumerate(info_lines):
-      start_x = (100 - len(line)) // 2 + 15  # Center the info lines plus empirical adjustment of 15
+      start_x = (100 - len(line)) // 2 + 15
       for j, char in enumerate(line):
       for j, char in enumerate(line):
-        if 0 <= start_x + j < 100 and info_start_y + i < 55:
+        if 0 <= start_x + j < 100 and info_start_y + i < 48:
           visualization[info_start_y + i][start_x + j] = char
           visualization[info_start_y + i][start_x + j] = char
 
 
     # Calculate total FLOPS and position on the bar
     # Calculate total FLOPS and position on the bar
@@ -105,13 +99,13 @@ class TopologyViz:
     bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
     bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
 
 
     # Add GPU poor/rich bar
     # Add GPU poor/rich bar
-    bar_width = 30  # Increased bar width
-    bar_start_x = (100 - bar_width) // 2  # Center the bar
-    bar_y = info_start_y + len(info_lines) + 1  # Position the bar below the info section with two cells of space
+    bar_width = 30
+    bar_start_x = (100 - bar_width) // 2
+    bar_y = info_start_y + len(info_lines) + 1
 
 
     # Create a gradient bar using emojis
     # Create a gradient bar using emojis
     gradient_bar = Text()
     gradient_bar = Text()
-    emojis = ["🟥", "🟧", "🟨", "🟩"]  # Red, Orange, Yellow, Green
+    emojis = ["🟥", "🟧", "🟨", "🟩"]
     for i in range(bar_width):
     for i in range(bar_width):
       emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
       emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
       gradient_bar.append(emojis[emoji_index])
       gradient_bar.append(emojis[emoji_index])
@@ -133,6 +127,9 @@ class TopologyViz:
     visualization[bar_y + 1][pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
     visualization[bar_y + 1][pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
     visualization[bar_y + 2][pos_x] = "▲"
     visualization[bar_y + 2][pos_x] = "▲"
 
 
+    # Add an extra empty line for spacing
+    bar_y += 4
+
     for i, partition in enumerate(self.partitions):
     for i, partition in enumerate(self.partitions):
       device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
       device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
 
 
@@ -140,11 +137,13 @@ class TopologyViz:
       x = int(center_x + radius_x * math.cos(angle))
       x = int(center_x + radius_x * math.cos(angle))
       y = int(center_y + radius_y * math.sin(angle))
       y = int(center_y + radius_y * math.sin(angle))
 
 
-      # Place node with different color for active node
+      # Place node with different color for active node and this node
       if partition.node_id == self.topology.active_node_id:
       if partition.node_id == self.topology.active_node_id:
-        visualization[y][x] = "🔴"  # Red circle for active node
+        visualization[y][x] = "🔴"
+      elif partition.node_id == self.node_id:
+        visualization[y][x] = "🟢"
       else:
       else:
-        visualization[y][x] = "🔵"  # Blue circle for inactive nodes
+        visualization[y][x] = "🔵"
 
 
       # Place node info (model, memory, TFLOPS, partition) on three lines
       # Place node info (model, memory, TFLOPS, partition) on three lines
       node_info = [
       node_info = [
@@ -154,28 +153,27 @@ class TopologyViz:
       ]
       ]
 
 
       # Calculate info position based on angle
       # Calculate info position based on angle
-      info_distance_x = radius_x + 6  # Increased horizontal distance
-      info_distance_y = radius_y + 3  # Decreased vertical distance
+      info_distance_x = radius_x + 6
+      info_distance_y = radius_y + 3
       info_x = int(center_x + info_distance_x * math.cos(angle))
       info_x = int(center_x + info_distance_x * math.cos(angle))
       info_y = int(center_y + info_distance_y * math.sin(angle))
       info_y = int(center_y + info_distance_y * math.sin(angle))
 
 
       # Adjust text position to avoid overwriting the node icon and prevent cutoff
       # Adjust text position to avoid overwriting the node icon and prevent cutoff
-      if info_x < x:  # Text is to the left of the node
+      if info_x < x:
         info_x = max(0, x - len(max(node_info, key=len)) - 1)
         info_x = max(0, x - len(max(node_info, key=len)) - 1)
-      elif info_x > x:  # Text is to the right of the node
+      elif info_x > x:
         info_x = min(99 - len(max(node_info, key=len)), info_x)
         info_x = min(99 - len(max(node_info, key=len)), info_x)
 
 
       # Adjust for top and bottom nodes
       # Adjust for top and bottom nodes
-      if 5 * math.pi / 4 < angle < 7 * math.pi / 4:  # Node is near the top
-        info_x += 4  # Shift text slightly to the right
-      elif math.pi / 4 < angle < 3 * math.pi / 4:  # Node is near the bottom
-        info_x += 3  # Shift text slightly to the right
-        info_y -= 2  # Move text up by two cells
+      if 5 * math.pi / 4 < angle < 7 * math.pi / 4:
+        info_x += 4
+      elif math.pi / 4 < angle < 3 * math.pi / 4:
+        info_x += 3
+        info_y -= 2
 
 
       for j, line in enumerate(node_info):
       for j, line in enumerate(node_info):
         for k, char in enumerate(line):
         for k, char in enumerate(line):
-          if 0 <= info_y + j < 55 and 0 <= info_x + k < 100:  # Updated height check
-            # Ensure we're not overwriting the node icon
+          if 0 <= info_y + j < 48 and 0 <= info_x + k < 100:
             if info_y + j != y or info_x + k != x:
             if info_y + j != y or info_x + k != x:
               visualization[info_y + j][info_x + k] = char
               visualization[info_y + j][info_x + k] = char
 
 
@@ -190,33 +188,47 @@ class TopologyViz:
       for step in range(1, steps):
       for step in range(1, steps):
         line_x = int(x + (next_x - x) * step / steps)
         line_x = int(x + (next_x - x) * step / steps)
         line_y = int(y + (next_y - y) * step / steps)
         line_y = int(y + (next_y - y) * step / steps)
-        if 0 <= line_y < 55 and 0 <= line_x < 100:  # Updated height check
+        if 0 <= line_y < 48 and 0 <= line_x < 100:
           visualization[line_y][line_x] = "-"
           visualization[line_y][line_x] = "-"
 
 
     # Convert to string
     # Convert to string
     return "\n".join("".join(str(char) for char in row) for row in visualization)
     return "\n".join("".join(str(char) for char in row) for row in visualization)
 
 
-def _generate_download_summary(download_progress) -> Table:
+  def _generate_download_layout(self) -> Table:
     summary = Table(show_header=False, box=None, padding=(0, 1))
     summary = Table(show_header=False, box=None, padding=(0, 1))
     summary.add_column("Info", style="cyan", no_wrap=True)
     summary.add_column("Info", style="cyan", no_wrap=True)
     summary.add_column("Progress", style="cyan", no_wrap=True)
     summary.add_column("Progress", style="cyan", no_wrap=True)
     summary.add_column("Percentage", style="cyan", no_wrap=True)
     summary.add_column("Percentage", style="cyan", no_wrap=True)
 
 
-    title = f"Downloading model ({download_progress.completed_files}/{download_progress.total_files}):"
-    summary.add_row(Text(title, style="bold"))
-    progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
-    summary.add_row(progress_info)
+    # Current node download progress
+    if self.node_id in self.node_download_progress:
+        download_progress = self.node_download_progress[self.node_id]
+        title = f"Downloading model ({download_progress.completed_files}/{download_progress.total_files}):"
+        summary.add_row(Text(title, style="bold"))
+        progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
+        summary.add_row(progress_info)
+
+        eta_info = f"ETA: {download_progress.overall_eta}"
+        summary.add_row(eta_info)
+
+        summary.add_row("")  # Empty row for spacing
 
 
-    eta_info = f"ETA: {download_progress.overall_eta}"
-    summary.add_row(eta_info)
+        for file_path, file_progress in download_progress.file_progress.items():
+            if file_progress.status != "complete":
+                progress = int(file_progress.downloaded / file_progress.total * 20)
+                bar = f"[{'=' * progress}{' ' * (20 - progress)}]"
+                percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
+                summary.add_row(Text(file_path[:20], style="cyan"), bar, percentage)
 
 
     summary.add_row("")  # Empty row for spacing
     summary.add_row("")  # Empty row for spacing
 
 
-    for file_path, file_progress in download_progress.file_progress.items():
-      if file_progress.status != "complete":
-        progress = int(file_progress.downloaded / file_progress.total * 20)  # Increased bar width
-        bar = f"[{'=' * progress}{' ' * (20 - progress)}]"
-        percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
-        summary.add_row(Text(file_path[:20], style="cyan"), bar, percentage)  # Increased file path length
+    # Other nodes download progress summary
+    summary.add_row(Text("Other Nodes Download Progress:", style="bold"))
+    for node_id, progress in self.node_download_progress.items():
+        if node_id != self.node_id:
+            truncated_id = node_id[:8] + "..." if len(node_id) > 8 else node_id
+            percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
+            speed = pretty_print_bytes_per_second(progress.overall_speed)
+            summary.add_row(f"{truncated_id}: {percentage:.1f}% ({speed})")
 
 
     return summary
     return summary

+ 2 - 2
extra/download_hf.py

@@ -1,6 +1,6 @@
 import argparse
 import argparse
 import asyncio
 import asyncio
-from exo.inference.hf_helpers import download_all_files, HFRepoProgressEvent, HFRepoFileProgressEvent
+from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent
 
 
 DEFAULT_ALLOW_PATTERNS = [
 DEFAULT_ALLOW_PATTERNS = [
     "*.json",
     "*.json",
@@ -23,7 +23,7 @@ DEFAULT_IGNORE_PATTERNS = [
 ]
 ]
 
 
 async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
 async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
-    async def progress_callback(event: HFRepoProgressEvent):
+    async def progress_callback(event: RepoProgressEvent):
         print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
         print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
         print(f"Estimated time remaining: {event.overall_eta}")
         print(f"Estimated time remaining: {event.overall_eta}")
         print("File Progress:")
         print("File Progress:")

+ 16 - 5
main.py

@@ -2,12 +2,14 @@ import argparse
 import asyncio
 import asyncio
 import signal
 import signal
 import json
 import json
-import uuid
+import time
 from exo.orchestration.standard_node import StandardNode
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
 from exo.api import ChatGPTAPI
+from exo.download.shard_download import ShardDownloader
+from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id
 
 
 # parse args
 # parse args
@@ -22,7 +24,7 @@ parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
-parser.add_argument("--max-generate-tokens", type=int, default=256, help="Max tokens to generate in each request")
+parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 args = parser.parse_args()
 args = parser.parse_args()
@@ -32,9 +34,10 @@ print_yellow_exo()
 system_info = get_system_info()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 print(f"Detected system: {system_info}")
 
 
+shard_downloader: ShardDownloader = HFShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
-inference_engine = get_inference_engine(inference_engine_name)
-print(f"Using inference engine: {inference_engine.__class__.__name__}")
+inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
+print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 
 
 if args.node_port is None:
 if args.node_port is None:
     args.node_port = find_available_port(args.node_host)
     args.node_port = find_available_port(args.node_host)
@@ -60,7 +63,15 @@ node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference
 if args.prometheus_client_port:
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)
     start_metrics_server(node, args.prometheus_client_port)
-inference_engine.set_progress_callback(lambda event: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()}))))
+
+last_broadcast_time = 0
+def throttled_broadcast(shard, event):
+    global last_broadcast_time
+    current_time = time.time()
+    if current_time - last_broadcast_time >= 0.1:
+        last_broadcast_time = current_time
+        asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 
 async def shutdown(signal, loop):
 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""
     """Gracefully shutdown the server and close the asyncio loop."""