Browse Source

Merge pull request #124 from exo-explore/refactor_model_download

Refactor model download, refactor tinygrad
Alex Cheema 9 months ago
parent
commit
9e78c42b4b

+ 4 - 4
.circleci/config.yml

@@ -17,11 +17,11 @@ commands:
             source env/bin/activate
 
             # Start first instance
-            DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout-secs 900 > output1.log 2>&1 &
+            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout-secs 900 > output1.log 2>&1 &
             PID1=$!
 
             # Start second instance
-            DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout-secs 900 > output2.log 2>&1 &
+            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout-secs 900 > output2.log 2>&1 &
             PID2=$!
 
             # Wait for discovery
@@ -132,9 +132,9 @@ jobs:
           name: Run discovery integration test
           command: |
             source env/bin/activate
-            DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 > output1.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 > output1.log 2>&1 &
             PID1=$!
-            DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 > output2.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 > output2.log 2>&1 &
             PID2=$!
             sleep 10
             kill $PID1 $PID2

+ 18 - 38
exo/api/chatgpt_api.py

@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoProcessor
 from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
+import traceback
 from exo import DEBUG, VERSION
 from exo.helpers import terminal_link, PrefixDict
 from exo.inference.shard import Shard
@@ -16,20 +17,22 @@ shard_mappings = {
   ### llama
   "llama-3.1-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3.1-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
   },
   "llama-3.1-405b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
   },
   "llama-3-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
   },
   ### mistral
   "mistral-nemo": {
@@ -76,18 +79,10 @@ class ChatCompletionRequest:
         }
 
 
-def resolve_tinygrad_tokenizer(model_id: str):
-  if model_id == "llama3-8b-sfr":
-    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
-  elif model_id == "llama3-70b-sfr":
-    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
-  else:
-    raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")
-
 
 async def resolve_tokenizer(model_id: str):
   try:
-    if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}")
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
     processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
     if not hasattr(processor, 'eos_token_id'):
       processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
@@ -97,33 +92,17 @@ async def resolve_tokenizer(model_id: str):
       processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
-    if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}")
-    import traceback
-
-    if DEBUG >= 2: print(traceback.format_exc())
+    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())
 
   try:
-    if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
     return AutoTokenizer.from_pretrained(model_id)
   except Exception as e:
-    if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
-    import traceback
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())
 
-    if DEBUG >= 2: print(traceback.format_exc())
-
-  try:
-    if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
-    return resolve_tinygrad_tokenizer(model_id)
-  except Exception as e:
-    if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer. Error: {e}")
-    import traceback
-
-    if DEBUG >= 2: print(traceback.format_exc())
-
-  if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
-  from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
-
-  return load_tokenizer(await get_model_path(model_id))
+  raise ValueError(f"[TODO] Unsupported model: {model_id}")
 
 
 def generate_completion(
@@ -326,10 +305,7 @@ class ChatGPTAPI:
     try:
       await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
     except Exception as e:
-      if DEBUG >= 2:
-        import traceback
-
-        traceback.print_exc()
+      if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
 
     try:
@@ -370,7 +346,11 @@ class ChatGPTAPI:
             "chat.completion",
           )
           if DEBUG >= 2: print(f"Streaming completion: {completion}")
-          await response.write(f"data: {json.dumps(completion)}\n\n".encode())
+          try:
+            await response.write(f"data: {json.dumps(completion)}\n\n".encode())
+          except Exception as e:
+            if DEBUG >= 2: print(f"Error streaming completion: {e}")
+            if DEBUG >= 2: traceback.print_exc()
 
         def on_result(_request_id: str, tokens: List[int], is_finished: bool):
           self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))

+ 82 - 0
exo/download/download_progress.py

@@ -0,0 +1,82 @@
+from typing import Dict, Callable, Coroutine, Any, Literal
+from dataclasses import dataclass
+from datetime import timedelta
+
+@dataclass
+class RepoFileProgressEvent:
+    repo_id: str
+    repo_revision: str
+    file_path: str
+    downloaded: int
+    downloaded_this_session: int
+    total: int
+    speed: int
+    eta: timedelta
+    status: Literal["not_started", "in_progress", "complete"]
+
+    def to_dict(self):
+        return {
+            "repo_id": self.repo_id,
+            "repo_revision": self.repo_revision,
+            "file_path": self.file_path,
+            "downloaded": self.downloaded,
+            "downloaded_this_session": self.downloaded_this_session,
+            "total": self.total,
+            "speed": self.speed,
+            "eta": self.eta.total_seconds(),
+            "status": self.status
+        }
+
+    @classmethod
+    def from_dict(cls, data):
+        # Convert eta from seconds back to timedelta
+        if 'eta' in data:
+            data['eta'] = timedelta(seconds=data['eta'])
+        return cls(**data)
+
+@dataclass
+class RepoProgressEvent:
+    repo_id: str
+    repo_revision: str
+    completed_files: int
+    total_files: int
+    downloaded_bytes: int
+    downloaded_bytes_this_session: int
+    total_bytes: int
+    overall_speed: int
+    overall_eta: timedelta
+    file_progress: Dict[str, RepoFileProgressEvent]
+    status: Literal["not_started", "in_progress", "complete"]
+
+    def to_dict(self):
+        return {
+            "repo_id": self.repo_id,
+            "repo_revision": self.repo_revision,
+            "completed_files": self.completed_files,
+            "total_files": self.total_files,
+            "downloaded_bytes": self.downloaded_bytes,
+            "downloaded_bytes_this_session": self.downloaded_bytes_this_session,
+            "total_bytes": self.total_bytes,
+            "overall_speed": self.overall_speed,
+            "overall_eta": self.overall_eta.total_seconds(),
+            "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
+            "status": self.status
+        }
+
+    @classmethod
+    def from_dict(cls, data):
+        # Convert overall_eta from seconds back to timedelta
+        if 'overall_eta' in data:
+            data['overall_eta'] = timedelta(seconds=data['overall_eta'])
+
+        # Parse file_progress
+        if 'file_progress' in data:
+            data['file_progress'] = {
+                k: RepoFileProgressEvent.from_dict(v)
+                for k, v in data['file_progress'].items()
+            }
+
+        return cls(**data)
+
+RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
+RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]

+ 305 - 0
exo/download/hf/hf_helpers.py

@@ -0,0 +1,305 @@
+import asyncio
+import aiohttp
+import json
+import os
+from urllib.parse import urljoin
+from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
+from datetime import datetime, timedelta
+from fnmatch import fnmatch
+from pathlib import Path
+from typing import Generator, Iterable, TypeVar, TypedDict
+from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
+from exo.helpers import DEBUG
+from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
+from exo.inference.shard import Shard
+
+T = TypeVar("T")
+def filter_repo_objects(
+    items: Iterable[T],
+    *,
+    allow_patterns: Optional[Union[List[str], str]] = None,
+    ignore_patterns: Optional[Union[List[str], str]] = None,
+    key: Optional[Callable[[T], str]] = None,
+) -> Generator[T, None, None]:
+    if isinstance(allow_patterns, str):
+        allow_patterns = [allow_patterns]
+    if isinstance(ignore_patterns, str):
+        ignore_patterns = [ignore_patterns]
+    if allow_patterns is not None:
+        allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
+    if ignore_patterns is not None:
+        ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
+
+    if key is None:
+        def _identity(item: T) -> str:
+            if isinstance(item, str):
+                return item
+            if isinstance(item, Path):
+                return str(item)
+            raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
+        key = _identity
+
+    for item in items:
+        path = key(item)
+        if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
+            continue
+        if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
+            continue
+        yield item
+
+def _add_wildcard_to_directories(pattern: str) -> str:
+    if pattern[-1] == "/":
+        return pattern + "*"
+    return pattern
+
+def get_hf_home() -> Path:
+    """Get the Hugging Face home directory."""
+    return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
+
+def get_hf_token():
+    """Retrieve the Hugging Face token from the user's HF_HOME directory."""
+    token_path = get_hf_home() / "token"
+    if token_path.exists():
+        return token_path.read_text().strip()
+    return None
+
+def get_auth_headers():
+    """Get authentication headers if a token is available."""
+    token = get_hf_token()
+    if token:
+        return {"Authorization": f"Bearer {token}"}
+    return {}
+
+def get_repo_root(repo_id: str) -> Path:
+    """Get the root directory for a given repo ID in the Hugging Face cache."""
+    sanitized_repo_id = repo_id.replace("/", "--")
+    return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
+
+async def fetch_file_list(session, repo_id, revision, path=""):
+    api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
+    url = f"{api_url}/{path}" if path else api_url
+
+    headers = get_auth_headers()
+    async with session.get(url, headers=headers) as response:
+        if response.status == 200:
+            data = await response.json()
+            files = []
+            for item in data:
+                if item["type"] == "file":
+                    files.append({"path": item["path"], "size": item["size"]})
+                elif item["type"] == "directory":
+                    subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
+                    files.extend(subfiles)
+            return files
+        else:
+            raise Exception(f"Failed to fetch file list: {response.status}")
+
+
+@retry(
+    stop=stop_after_attempt(5),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
+    reraise=True
+)
+async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True):
+    base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
+    url = urljoin(base_url, file_path)
+    local_path = os.path.join(save_directory, file_path)
+
+    os.makedirs(os.path.dirname(local_path), exist_ok=True)
+
+    # Check if file already exists and get its size
+    local_file_size = os.path.getsize(local_path) if os.path.exists(local_path) else 0
+
+    headers = get_auth_headers()
+    if use_range_request:
+        headers["Range"] = f"bytes={local_file_size}-"
+
+    async with session.get(url, headers=headers) as response:
+        total_size = int(response.headers.get('Content-Length', 0))
+        downloaded_size = local_file_size
+        downloaded_this_session = 0
+        mode = 'ab' if use_range_request else 'wb'
+        if downloaded_size == total_size:
+            if DEBUG >= 2: print(f"File already downloaded: {file_path}")
+            if progress_callback:
+                await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+            return
+
+        if response.status == 200:
+            # File doesn't support range requests or we're not using them, start from beginning
+            mode = 'wb'
+            downloaded_size = 0
+        elif response.status == 206:
+            # Partial content, resume download
+            content_range = response.headers.get('Content-Range', '')
+            try:
+                total_size = int(content_range.split('/')[-1])
+            except ValueError:
+                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
+        elif response.status == 416:
+            # Range not satisfiable, get the actual file size
+            content_range = response.headers.get('Content-Range', '')
+            try:
+                total_size = int(content_range.split('/')[-1])
+                if downloaded_size == total_size:
+                    if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
+                    if progress_callback:
+                        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+                    return
+            except ValueError:
+                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
+        else:
+            raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
+
+        if downloaded_size == total_size:
+            print(f"File already downloaded: {file_path}")
+            if progress_callback:
+                await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+            return
+
+        DOWNLOAD_CHUNK_SIZE = 32768
+        start_time = datetime.now()
+        with open(local_path, mode) as f:
+            async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
+                f.write(chunk)
+                downloaded_size += len(chunk)
+                downloaded_this_session += len(chunk)
+                if progress_callback and total_size:
+                    elapsed_time = (datetime.now() - start_time).total_seconds()
+                    speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
+                    remaining_size = total_size - downloaded_size
+                    eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
+                    status = "in_progress" if downloaded_size < total_size else "complete"
+                    if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
+                    await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
+        if DEBUG >= 2: print(f"Downloaded: {file_path}")
+
+async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None) -> Path:
+    repo_root = get_repo_root(repo_id)
+    refs_dir = repo_root / "refs"
+    snapshots_dir = repo_root / "snapshots"
+
+    # Ensure directories exist
+    refs_dir.mkdir(parents=True, exist_ok=True)
+    snapshots_dir.mkdir(parents=True, exist_ok=True)
+
+    async with aiohttp.ClientSession() as session:
+        # Fetch the commit hash for the given revision
+        api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
+        headers = get_auth_headers()
+        async with session.get(api_url, headers=headers) as response:
+            if response.status != 200:
+                raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
+            revision_info = await response.json()
+            commit_hash = revision_info['sha']
+
+        # Write the commit hash to the refs file
+        refs_file = refs_dir / revision
+        refs_file.write_text(commit_hash)
+
+        # Set up the snapshot directory
+        snapshot_dir = snapshots_dir / commit_hash
+        snapshot_dir.mkdir(exist_ok=True)
+
+        file_list = await fetch_file_list(session, repo_id, revision)
+        filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
+        total_files = len(filtered_file_list)
+        total_bytes = sum(file["size"] for file in filtered_file_list)
+        file_progress: Dict[str, RepoFileProgressEvent] = {file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
+        start_time = datetime.now()
+
+        async def download_with_progress(file_info, progress_state):
+            async def file_progress_callback(event: RepoFileProgressEvent):
+                progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
+                progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
+                file_progress[event.file_path] = event
+                if progress_callback:
+                    elapsed_time = (datetime.now() - start_time).total_seconds()
+                    overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+                    remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+                    overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+                    status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
+                    await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+
+            await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
+            progress_state['completed_files'] += 1
+            file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
+            if progress_callback:
+                elapsed_time = (datetime.now() - start_time).total_seconds()
+                overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+                remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+                overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+                status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+                await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+
+        progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
+        tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
+        await asyncio.gather(*tasks)
+
+    return snapshot_dir
+
+async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
+    """
+    Retrieve the weight map from the model.safetensors.index.json file.
+
+    Args:
+        repo_id (str): The Hugging Face repository ID.
+        revision (str): The revision of the repository to use.
+
+    Returns:
+        Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
+    """
+
+    # Download the index file
+    await download_repo_files(
+        repo_id=repo_id,
+        revision=revision,
+        allow_patterns="model.safetensors.index.json"
+    )
+
+    # Check if the file exists
+    repo_root = get_repo_root(repo_id)
+    snapshot_dir = repo_root / "snapshots"
+    index_file = next(snapshot_dir.glob("*/model.safetensors.index.json"), None)
+
+    if index_file and index_file.exists():
+        with open(index_file, 'r') as f:
+            index_data = json.load(f)
+        return index_data.get("weight_map")
+
+    return None
+
+def extract_layer_num(tensor_name: str) -> Optional[int]:
+    # This is a simple example and might need to be adjusted based on the actual naming convention
+    parts = tensor_name.split('.')
+    for part in parts:
+        if part.isdigit():
+            return int(part)
+    return None
+
+
+def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
+    default_patterns = [
+        "*.json",
+        "*.py",
+        "tokenizer.model",
+        "*.tiktoken",
+        "*.txt",
+    ]
+    shard_specific_patterns = []
+    if weight_map:
+        for tensor_name, filename in weight_map.items():
+            layer_num = extract_layer_num(tensor_name)
+            if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
+                shard_specific_patterns.append(filename)
+        sorted_file_names = sorted(weight_map.values())
+        if shard.is_first_layer():
+            shard_specific_patterns.append(sorted_file_names[0])
+        elif shard.is_last_layer():
+            shard_specific_patterns.append(sorted_file_names[-1])
+    else:
+        shard_specific_patterns = ["*.safetensors"]
+    return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates

+ 70 - 0
exo/download/hf/hf_shard_download.py

@@ -0,0 +1,70 @@
+import asyncio
+import traceback
+from pathlib import Path
+from typing import Dict, List, Tuple
+from exo.inference.shard import Shard
+from exo.download.shard_download import ShardDownloader
+from exo.download.download_progress import RepoProgressEvent
+from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns
+from exo.helpers import AsyncCallbackSystem, DEBUG
+
+class HFShardDownloader(ShardDownloader):
+    def __init__(self):
+        self.active_downloads: Dict[Shard, asyncio.Task] = {}
+        self.completed_downloads: Dict[Shard, Path] = {}
+        self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+
+    async def ensure_shard(self, shard: Shard) -> Path:
+        if shard in self.completed_downloads:
+            return self.completed_downloads[shard]
+
+        # If a download on this shard is already in progress, keep that one
+        for active_shard in self.active_downloads:
+            if active_shard == shard:
+                if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
+                return await self.active_downloads[shard]
+
+        # Cancel any downloads for this model_id on a different shard
+        existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
+        for active_shard in existing_active_shards:
+            if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
+            task = self.active_downloads[active_shard]
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass  # This is expected when cancelling a task
+            except Exception as e:
+                if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
+                traceback.print_exc()
+        self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
+
+        # Start new download
+        download_task = asyncio.create_task(self._download_shard(shard))
+        self.active_downloads[shard] = download_task
+        try:
+            path = await download_task
+            self.completed_downloads[shard] = path
+            return path
+        finally:
+            # Ensure the task is removed even if an exception occurs
+            print(f"Removing download task for {shard}: {shard in self.active_downloads}")
+            if shard in self.active_downloads:
+                self.active_downloads.pop(shard)
+
+    async def _download_shard(self, shard: Shard) -> Path:
+        async def wrapped_progress_callback(event: RepoProgressEvent):
+            self._on_progress.trigger_all(shard, event)
+
+        weight_map = await get_weight_map(shard.model_id)
+        allow_patterns = get_allow_patterns(weight_map, shard)
+
+        return await download_repo_files(
+            repo_id=shard.model_id,
+            progress_callback=wrapped_progress_callback,
+            allow_patterns=allow_patterns
+        )
+
+    @property
+    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+        return self._on_progress

+ 25 - 0
exo/download/shard_download.py

@@ -0,0 +1,25 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple
+from pathlib import Path
+from exo.inference.shard import Shard
+from exo.download.download_progress import RepoProgressEvent
+from exo.helpers import AsyncCallbackSystem
+
+class ShardDownloader(ABC):
+    @abstractmethod
+    async def ensure_shard(self, shard: Shard) -> Path:
+        """
+        Ensures that the shard is downloaded.
+        Does not allow multiple overlapping downloads at once.
+        If you try to download a Shard which overlaps a Shard that is already being downloaded,
+        the download will be cancelled and a new download will start.
+
+        Args:
+            shard (Shard): The shard to download.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+        pass

+ 27 - 3
exo/helpers.py

@@ -31,17 +31,17 @@ def get_system_info():
   return "Non-Mac, non-Linux system"
 
 
-def get_inference_engine(inference_engine_name):
+def get_inference_engine(inference_engine_name, shard_downloader: 'ShardDownloader'):
   if inference_engine_name == "mlx":
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 
-    return MLXDynamicShardInferenceEngine()
+    return MLXDynamicShardInferenceEngine(shard_downloader)
   elif inference_engine_name == "tinygrad":
     from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
     import tinygrad.helpers
     tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
 
-    return TinygradDynamicShardInferenceEngine()
+    return TinygradDynamicShardInferenceEngine(shard_downloader)
   else:
     raise ValueError(f"Inference engine {inference_engine_name} not supported")
 
@@ -201,3 +201,27 @@ def get_or_create_node_id():
     except Exception as e:
         if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
         return str(uuid.uuid4())
+
+def pretty_print_bytes(size_in_bytes: int) -> str:
+    if size_in_bytes < 1024:
+        return f"{size_in_bytes} B"
+    elif size_in_bytes < 1024 ** 2:
+        return f"{size_in_bytes / 1024:.2f} KB"
+    elif size_in_bytes < 1024 ** 3:
+        return f"{size_in_bytes / (1024 ** 2):.2f} MB"
+    elif size_in_bytes < 1024 ** 4:
+        return f"{size_in_bytes / (1024 ** 3):.2f} GB"
+    else:
+        return f"{size_in_bytes / (1024 ** 4):.2f} TB"
+
+def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
+    if bytes_per_second < 1024:
+        return f"{bytes_per_second} B/s"
+    elif bytes_per_second < 1024 ** 2:
+        return f"{bytes_per_second / 1024:.2f} KB/s"
+    elif bytes_per_second < 1024 ** 3:
+        return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
+    elif bytes_per_second < 1024 ** 4:
+        return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
+    else:
+        return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"

+ 1 - 6
exo/inference/inference_engine.py

@@ -1,10 +1,9 @@
 import numpy as np
 
-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from .shard import Shard
 
-
 class InferenceEngine(ABC):
   @abstractmethod
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
@@ -13,7 +12,3 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
-
-  @abstractmethod
-  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
-    pass

+ 6 - 7
exo/inference/mlx/sharded_inference_engine.py

@@ -4,13 +4,14 @@ from ..inference_engine import InferenceEngine
 from .sharded_model import StatefulShardedModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
-from typing import Optional, Callable
+from typing import Optional
+from exo.download.shard_download import ShardDownloader
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self, on_download_progress: Callable[[int, int], None] = None):
+  def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
-    self.on_download_progress = on_download_progress
+    self.shard_downloader = shard_downloader
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
@@ -33,9 +34,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return
 
-    model_shard, self.tokenizer = await load_shard(shard.model_id, shard, on_download_progress=self.on_download_progress)
+    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_shard, self.tokenizer = await load_shard(model_path, shard)
     self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
     self.shard = shard
-
-  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
-    self.on_download_progress = on_download_progress

+ 2 - 223
exo/inference/mlx/sharded_utils.py

@@ -12,22 +12,15 @@ from typing import Optional, Tuple, Union, List, Callable
 from PIL import Image
 from io import BytesIO
 import base64
-import os
-import concurrent.futures
 
-from exo import DEBUG
 import mlx.core as mx
 import mlx.nn as nn
-from huggingface_hub import snapshot_download, list_repo_tree, get_paths_info
-from huggingface_hub.utils import filter_repo_objects
-from huggingface_hub.file_download import repo_folder_name
-from huggingface_hub.constants import HF_HUB_CACHE
-from huggingface_hub.utils._errors import RepositoryNotFoundError
 from transformers import AutoProcessor
 
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tuner.utils import apply_lora_layers
 
+from exo import DEBUG
 from ..shard import Shard
 
 
@@ -163,228 +156,14 @@ def load_model_shard(
   model.eval()
   return model
 
-
-async def get_repo_size(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
-  it = await asyncio.to_thread(list_repo_tree, repo_id, revision=revision, repo_type=repo_type)
-  files = list(filter_repo_objects(it, allow_patterns=allow_patterns, key=lambda f: f.path))
-  return sum(file.size for file in files if hasattr(file, "size") and file.size is not None)
-
-async def monitor_progress(dir, total_size, print_progress=False, on_progress: Callable[[int, int], None] = None):
-    while True:
-      try:
-        await asyncio.sleep(0.1)
-        current_size = sum(os.path.getsize(os.path.join(root, file))
-                            for root, _, files in os.walk(dir)
-                            for file in files)
-        progress = min(current_size / total_size * 100, 100)
-        if print_progress:
-          print(f"\rProgress: {progress:.2f}% ({current_size}/{total_size} bytes)", end="", flush=True)
-        if on_progress:
-          on_progress(current_size, total_size)
-        if progress >= 100:
-          if print_progress:
-            print("\nDownload complete!")
-          break
-      except Exception as e:
-        print(f"Error monitoring progress: {e}")
-
-async def download_repo(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
-    with concurrent.futures.ThreadPoolExecutor() as pool:
-        return await asyncio.get_event_loop().run_in_executor(
-            pool,
-            partial(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
-        )
-
-async def download_async_with_progress(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, on_progress: Callable[[int, int], None] = None):
-  storage_folder = os.path.join(HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type="model"))
-  # os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
-  # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
-
-  total_size = await get_repo_size(repo_id)
-
-  # Create tasks for download and progress checking
-  download_task = asyncio.create_task(download_repo(repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type))
-  progress_task = asyncio.create_task(monitor_progress(storage_folder, total_size, on_progress=on_progress))
-
-  # Wait for both tasks to complete
-  result = await asyncio.gather(download_task, progress_task, return_exceptions=True)
-  return result[0]  # Return the result from download_task
-
-repo_id_safetensors_layers = {
-  "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit": {
-    "model.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
-  },
-  "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit": {
-    "model-00001-of-00008.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
-    "model-00002-of-00008.safetensors": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
-    "model-00003-of-00008.safetensors": [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
-    "model-00004-of-00008.safetensors": [31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42],
-    "model-00005-of-00008.safetensors": [42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53],
-    "model-00006-of-00008.safetensors": [53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64],
-    "model-00007-of-00008.safetensors": [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75],
-    "model-00008-of-00008.safetensors": [75, 76, 77, 78, 79],
-  },
-  "mlx-community/Meta-Llama-3.1-405B-Instruct-4bit": {
-    "model-00001-of-00046.safetensors": [0, 1, 2],
-    "model-00002-of-00046.safetensors": [2, 3, 4, 5],
-    "model-00003-of-00046.safetensors": [5, 6, 7],
-    "model-00004-of-00046.safetensors": [8, 9, 10],
-    "model-00005-of-00046.safetensors": [10, 11, 12, 13],
-    "model-00006-of-00046.safetensors": [13, 14, 15, 16],
-    "model-00007-of-00046.safetensors": [16, 17, 18, 19],
-    "model-00008-of-00046.safetensors": [19, 20, 21],
-    "model-00009-of-00046.safetensors": [22, 23, 24],
-    "model-00010-of-00046.safetensors": [24, 25, 26, 27],
-    "model-00011-of-00046.safetensors": [27, 28, 29, 30],
-    "model-00012-of-00046.safetensors": [30, 31, 32, 33],
-    "model-00013-of-00046.safetensors": [33, 34, 35],
-    "model-00014-of-00046.safetensors": [36, 37, 38],
-    "model-00015-of-00046.safetensors": [38, 39, 40, 41],
-    "model-00016-of-00046.safetensors": [41, 42, 43, 44],
-    "model-00017-of-00046.safetensors": [44, 45, 46, 47],
-    "model-00018-of-00046.safetensors": [47, 48, 49],
-    "model-00019-of-00046.safetensors": [50, 51, 52],
-    "model-00020-of-00046.safetensors": [52, 53, 54, 55],
-    "model-00021-of-00046.safetensors": [55, 56, 57, 58],
-    "model-00022-of-00046.safetensors": [58, 59, 60, 61],
-    "model-00023-of-00046.safetensors": [61, 62, 63],
-    "model-00024-of-00046.safetensors": [64, 65, 66],
-    "model-00025-of-00046.safetensors": [66, 67, 68, 69],
-    "model-00026-of-00046.safetensors": [69, 70, 71, 72],
-    "model-00027-of-00046.safetensors": [72, 73, 74, 75],
-    "model-00028-of-00046.safetensors": [75, 76, 77],
-    "model-00029-of-00046.safetensors": [78, 79, 80],
-    "model-00030-of-00046.safetensors": [80, 81, 82, 83],
-    "model-00031-of-00046.safetensors": [83, 84, 85, 86],
-    "model-00032-of-00046.safetensors": [86, 87, 88, 89],
-    "model-00033-of-00046.safetensors": [89, 90, 91],
-    "model-00034-of-00046.safetensors": [92, 93, 94],
-    "model-00035-of-00046.safetensors": [94, 95, 96, 97],
-    "model-00036-of-00046.safetensors": [97, 98, 99, 100],
-    "model-00037-of-00046.safetensors": [100, 101, 102, 103],
-    "model-00038-of-00046.safetensors": [103, 104, 105],
-    "model-00039-of-00046.safetensors": [106, 107, 108],
-    "model-00040-of-00046.safetensors": [108, 109, 110, 111],
-    "model-00041-of-00046.safetensors": [111, 112, 113, 114],
-    "model-00042-of-00046.safetensors": [114, 115, 116, 117],
-    "model-00043-of-00046.safetensors": [117, 118, 119],
-    "model-00044-of-00046.safetensors": [120, 121, 122],
-    "model-00045-of-00046.safetensors": [122, 123, 124, 125],
-    "model-00046-of-00046.safetensors": [125]
-  },
-  "mlx-community/Mistral-Nemo-Instruct-2407-4bit": {
-    "model-00001-of-00002.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32],
-    "model-00002-of-00002.safetensors": [32, 33, 34, 35, 36, 37, 38, 39],
-  },
-  "mlx-community/Mistral-Large-Instruct-2407-4bit": {
-    "model-00001-of-00014.safetensors": [0, 1, 2, 3, 4, 5, 6],
-    "model-00002-of-00014.safetensors": [6, 7, 8, 9, 10, 11, 12, 13],
-    "model-00003-of-00014.safetensors": [13, 14, 15, 16, 17, 18, 19, 20],
-    "model-00004-of-00014.safetensors": [20, 21, 22, 23, 24, 25, 26],
-    "model-00005-of-00014.safetensors": [27, 28, 29, 30, 31, 32, 33],
-    "model-00006-of-00014.safetensors": [33, 34, 35, 36, 37, 38, 39, 40],
-    "model-00007-of-00014.safetensors": [40, 41, 42, 43, 44, 45, 46, 47],
-    "model-00008-of-00014.safetensors": [47, 48, 49, 50, 51, 52, 53, 54],
-    "model-00009-of-00014.safetensors": [54, 55, 56, 57, 58, 59, 60],
-    "model-00010-of-00014.safetensors": [61, 62, 63, 64, 65, 66, 67],
-    "model-00011-of-00014.safetensors": [67, 68, 69, 70, 71, 72, 73, 74],
-    "model-00012-of-00014.safetensors": [74, 75, 76, 77, 78, 79, 80, 81],
-    "model-00013-of-00014.safetensors": [81, 82, 83, 84, 85, 86, 87],
-    "model-00014-of-00014.safetensors": [87]
-  },
-  "llava-hf/llava-1.5-7b-hf": {
-    "model-00001-of-00003.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
-    "model-00002-of-00003.safetensors": [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22],
-    "model-00003-of-00003.safetensors": [22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
-  }
-}
-
-def get_safetensors_allow_patterns(repo_id: str, shard: Optional[Shard] = None):
-    return ["*.safetensors"] # TODO: enable this
-    if not shard:
-      return ["*.safetensors"]
-
-    allow_patterns = []
-    for repo_id, safetensors_layers in repo_id_safetensors_layers.items():
-        if repo_id == shard.model_id:
-            for safetensor, layers in safetensors_layers.items():
-                if any(shard.start_layer <= layer <= shard.end_layer for layer in layers):
-                    allow_patterns.append(safetensor)
-
-    return allow_patterns if len(allow_patterns) > 0 else ["*.safetensors"]
-
-async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
-  """
-  Ensures the model is available locally. If the path does not exist locally,
-  it is downloaded from the Hugging Face Hub.
-
-  Args:
-   path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
-   revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
-
-  Returns:
-   Path: The path to the model.
-  """
-  model_path = Path(path_or_hf_repo)
-  if not model_path.exists():
-    try:
-      model_path = Path(
-        await download_async_with_progress(
-          repo_id=path_or_hf_repo,
-          revision=revision,
-          allow_patterns=[
-            "*.json",
-            "*.py",
-            "tokenizer.model",
-            "*.tiktoken",
-            "*.txt",
-          ] + get_safetensors_allow_patterns(path_or_hf_repo, shard),
-          on_progress=on_download_progress,
-        )
-      )
-    except RepositoryNotFoundError:
-      raise ModelNotFoundError(
-        f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
-        "Please make sure you specified the local path or Hugging Face"
-        " repo id correctly.\nIf you are trying to access a private or"
-        " gated Hugging Face repo, make sure you are authenticated:\n"
-        "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
-      ) from None
-  return model_path
-
-
 async def load_shard(
-  path_or_hf_repo: str,
+  model_path: str,
   shard: Shard,
   tokenizer_config={},
   model_config={},
   adapter_path: Optional[str] = None,
   lazy: bool = False,
-  on_download_progress: Callable[[int, int], None] = None,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
-  """
-  Load the model and tokenizer from a given path or a huggingface repository.
-
-  Args:
-   path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
-   tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
-    Defaults to an empty dictionary.
-   model_config(dict, optional): Configuration parameters specifically for the model.
-    Defaults to an empty dictionary.
-   adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
-    to the model. Default: ``None``.
-   lazy (bool): If False eval the model parameters to make sure they are
-    loaded in memory before returning, otherwise they will be loaded
-    when needed. Default: ``False``
-  Returns:
-   Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
-
-  Raises:
-   FileNotFoundError: If config file or safetensors are not found.
-   ValueError: If model class or args class are not found.
-  """
-  model_path = await get_model_path(path_or_hf_repo, shard, on_download_progress=on_download_progress)
-
   model = load_model_shard(model_path, shard, lazy, model_config)
   if adapter_path is not None:
     model = apply_lora_layers(model, adapter_path)

+ 17 - 2
exo/inference/shard.py

@@ -1,13 +1,16 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 
-@dataclass
+@dataclass(frozen=True)
 class Shard:
   model_id: str
   start_layer: int
   end_layer: int
   n_layers: int
 
+  def __hash__(self):
+    return hash((self.model_id, self.start_layer, self.end_layer, self.n_layers))
+
   def is_first_layer(self) -> bool:
     return self.start_layer == 0
 
@@ -24,3 +27,15 @@ class Shard:
       "end_layer": self.end_layer,
       "n_layers": self.n_layers,
     }
+
+  def from_dict(data: dict) -> 'Shard':
+    return Shard(**data)
+
+  def overlaps(self, other: 'Shard') -> bool:
+    return shards_overlap(self, other)
+
+def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
+  return (
+      shard1.model_id == shard2.model_id
+      and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer)
+  )

+ 8 - 5
exo/inference/test_inference_engine.py

@@ -1,4 +1,6 @@
+from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 import asyncio
@@ -9,6 +11,7 @@ import numpy as np
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+
   next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
@@ -42,15 +45,15 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
 
 asyncio.run(
   test_inference_engine(
-    MLXDynamicShardInferenceEngine(),
-    MLXDynamicShardInferenceEngine(),
+    MLXDynamicShardInferenceEngine(HFShardDownloader()),
+    MLXDynamicShardInferenceEngine(HFShardDownloader()),
     "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
   )
 )
 
 # TODO: Need more memory or a smaller model
 # asyncio.run(test_inference_engine(
-#     TinygradDynamicShardInferenceEngine(),
-#     TinygradDynamicShardInferenceEngine(),
-#     "llama3-8b-sfr",
+#     TinygradDynamicShardInferenceEngine(HFShardDownloader()),
+#     TinygradDynamicShardInferenceEngine(HFShardDownloader()),
+#     "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
 # ))

+ 52 - 221
exo/inference/tinygrad/inference.py

@@ -1,266 +1,97 @@
-import asyncio
-from functools import partial
 from pathlib import Path
-from typing import List, Optional, Union, Callable
+from typing import List
 import json
-import tiktoken
-from tiktoken.load import load_tiktoken_bpe
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
-from tinygrad.nn.state import safe_load, torch_load, load_state_dict
-from tinygrad import Tensor, nn, Context, GlobalCounters
-from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
 from exo.inference.shard import Shard
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict
+from tinygrad import Tensor, dtypes, nn, Context
+from transformers import AutoTokenizer
 from exo.inference.inference_engine import InferenceEngine
+from typing import Optional, Tuple
 import numpy as np
-import os
+from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
+from exo.download.shard_download import ShardDownloader
 
+Tensor.no_grad = True
+# default settings
+TEMPERATURE = 0.85
+TOP_K = 25
+TOP_P = 0.9
+ALPHA_F = 0.1
+ALPHA_P = 0.0
 MODEL_PARAMS = {
   "8B": {
-    "args": {
-      "dim": 4096,
-      "n_heads": 32,
-      "n_kv_heads": 8,
-      "n_layers": 32,
-      "norm_eps": 1e-5,
-      "rope_theta": 500000,
-      "vocab_size": 128256,
-      "hidden_dim": 14336,
-    },
-    "files": 1,
+    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
+    "files": 1
   },
   "70B": {
-    "args": {
-      "dim": 8192,
-      "n_heads": 64,
-      "n_kv_heads": 8,
-      "n_layers": 80,
-      "norm_eps": 1e-5,
-      "rope_theta": 500000,
-      "vocab_size": 128256,
-      "hidden_dim": 28672,
-    },
-    "files": 8,
-  },
+    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 28672},
+    "files": 8
+  }
 }
 
-
-
-# **** helper functions ****
-async def fetch_async(
-  url: str,
-  name: Optional[Union[Path, str]] = None,
-  subdir: Optional[str] = None,
-  allow_caching=not os.getenv("DISABLE_HTTP_CACHE"),
-) -> Path:
-  func = partial(fetch, url, name, subdir, allow_caching)
-  return await asyncio.get_event_loop().run_in_executor(None, func)
-
-
-def concat_weights(models, device=None):
-  def convert(name) -> Tensor:
-    disk_tensors: List[Tensor] = [model[name] for model in models]
-    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
-      return disk_tensors[0].to(device=device)
-    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
-    lazy_tensors = [data.to(device=device) for data in disk_tensors]
-    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
-
-  return {name: convert(name) for name in {name: None for model in models for name in model}}
-
-
-def load(fn: str):
-  if fn.endswith(".index.json"):
-    with open(fn) as fp:
-      weight_map = json.load(fp)["weight_map"]
-    parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
-    return {k: parts[n][k] for k, n in weight_map.items()}
-  elif fn.endswith(".safetensors"):
-    return safe_load(fn)
-  else:
-    return torch_load(fn)
-
-
-def build_transformer(model_path: Path, shard: Shard, model_size="8B", quantize=None, device=None):
+def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   # build model
   linear = nn.Linear
   with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], shard=shard, linear=linear, max_context=8192, jit=False)
+    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
 
   # load weights
   if model_path.is_dir():
-    if (model_path / "model.safetensors.index.json").exists():
-      weights = load(str(model_path / "model.safetensors.index.json"))
-    elif (model_path / "model.safetensors").exists():
-      weights = load(str(model_path / "model.safetensors"))
-    else:
-      weights = concat_weights(
-        [load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])],
-        device[0] if isinstance(device, tuple) else device,
-      )
+    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"), shard)
+    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"), shard)
+    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
   else:
-    weights = load(str(model_path))
+    weights = load(str(model_path), shard)
   if "model.embed_tokens.weight" in weights:
-    weights = convert_from_huggingface(
-      weights,
-      model,
-      MODEL_PARAMS[model_size]["args"]["n_heads"],
-      MODEL_PARAMS[model_size]["args"]["n_kv_heads"],
-      shard=shard,
-    )
+    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
   weights = fix_bf16(weights)
 
   with Context(BEAM=0):
-    # quantize
-    if quantize is not None:
-      weights = linear.quantize(weights, device)
-      for _, v in weights.items():
-        v.realize()
-
-    # shard
-    if isinstance(device, tuple):
-      for k, v in nn.state.get_state_dict(model).items():
-        if "scale" in k:
-          v.shard_(device, axis=None)  # from quantized
-        elif ".attention." in k:
-          v.shard_(device, axis=-1)
-        elif ".feed_forward.w1." in k:
-          v.shard_(device, axis=0)
-        elif ".feed_forward.w3." in k:
-          v.shard_(device, axis=0)
-        elif ".feed_forward." in k:
-          v.shard_(device, axis=-1)
-        elif "tok_embeddings.weight" in k:
-          v.shard_(device, axis=0)
-        elif "output.weight" in k:
-          v.shard_(device, axis=0)
-        else:
-          v.shard_(device, axis=None)
-
     # replace weights in model
-    load_state_dict(model, weights, strict=False, consume=True)
+    load_state_dict(model, weights, strict=False, consume=False) # consume=True
   return model
 
-
-# default settings
-TEMPERATURE = 0  # 0.85
-TOP_K = 25
-TOP_P = 0.9
-ALPHA_F = 0.1
-ALPHA_P = 0.0
-
-
-def prefill(model, toks, start_pos=0):
-  # prefill the model
-  for tok in tqdm(toks):
-    GlobalCounters.reset()
-    model(Tensor([[tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).realize()
-    start_pos += 1
-  return start_pos
-
-
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self):
+  def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
+    self.shard_downloader = shard_downloader
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-    # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
     await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
+    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
+    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
     toks = self.tokenizer.encode(prompt)
-    start_pos = prefill(self.model, toks[:-1], start_pos=start_pos)
-    last_tok = toks[-1]
-
-    output_data = np.array([self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
-    if output_data.size == 1:
-      start_pos += 1
+    h = self.model(Tensor([toks]), start_pos, TEMPERATURE)
 
-    return (
-      output_data,
-      json.dumps({"start_pos": start_pos}),
-      output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id],
-    )
+    if h.shape == (1,):
+      start_pos += len(toks)
+      n_captured_toks = 1
+      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
+    else:
+      n_captured_toks += len(toks)
+      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks + 1}), False
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
+    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
+    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
-    output_data: np.ndarray = np.array([self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
-    if output_data.size == 1:
-      start_pos += 1
+    h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
 
-    return (
-      output_data,
-      json.dumps({"start_pos": start_pos}),
-      output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id],
-    )
+    if h.shape == (1,):
+      start_pos += n_captured_toks
+      n_captured_toks = 1
+      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
+    else:
+      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return
 
-    model_path = Path(shard.model_id)
-    models_dir = Path(_cache_dir) / "tinygrad" / "downloads"
-    model_path = models_dir / shard.model_id
-    size = "8B"
-    if Path(model_path / "tokenizer_config.json").exists():
-      model = model_path
-    else:
-
-      if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
-      if shard.model_id.lower().find("llama3-8b-sfr") != -1:
-        num_files = 4
-        for i in range(num_files):
-          await fetch_async(
-            f"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/model-{(i+1):05d}-of-{num_files:05d}.safetensors",
-            f"model-{(i+1):05d}-of-{num_files:05d}.safetensors",
-            subdir=shard.model_id,
-          )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/config.json",
-          "config.json",
-          subdir=shard.model_id,
-        )
-        model = await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/raw/main/model.safetensors.index.json",
-          "model.safetensors.index.json",
-          subdir=shard.model_id,
-        )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/special_tokens_map.json",
-          "special_tokens_map.json",
-          subdir=shard.model_id,
-        )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json",
-          "tokenizer.json",
-          subdir=shard.model_id,
-        )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer_config.json",
-          "tokenizer_config.json",
-          subdir=shard.model_id,
-        )
-        size = "8B"
-      elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
-        raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
-        # fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-70B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
-        # model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
-        # size = "70B"
-      else:
-        raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}")
-
-    model = build_transformer(model_path, shard=shard, model_size=size)
-    from transformers import AutoTokenizer
-    tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
-
+    model_path = await self.shard_downloader.ensure_shard(shard)
+    self.model = build_transformer(model_path, shard, model_size="8B")
+    self.tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
     self.shard = shard
-    self.model = model
-    self.tokenizer = tokenizer
-
-  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
-    pass

+ 53 - 125
exo/inference/tinygrad/models/llama.py

@@ -1,26 +1,22 @@
 from typing import Tuple, Union, Optional, Dict, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
-from exo.inference.shard import Shard
-
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
-  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[: (dim // 2)] / dim))
+  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
-  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
-
+  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim//2, 2)
 
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):
-  a, b = A[..., 0:1], A[..., 1:2]
-  ro = a * c - b * d
-  co = a * d + b * c
+  a,b = A[..., 0:1], A[..., 1:2]
+  ro = a*c - b*d
+  co = a*d + b*c
   return ro.cat(co, dim=-1)
 
-
-def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
+def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
   assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
@@ -30,19 +26,16 @@ def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor,
   xk_out = complex_mult(xk, c, d)
   return xq_out.flatten(3), xk_out.flatten(3)
 
-
-def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
+def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
   bs, seqlen, n_kv_heads, head_dim = x.shape
-  if n_rep == 1:
-    return x
+  if n_rep == 1: return x
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
 
-
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
-    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
+    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
     self.head_dim = dim // n_heads
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context
@@ -52,8 +45,14 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
-    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
+    if getenv("WQKV"):
+      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
+      xqkv = x @ self.wqkv.T
+      xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
+    else:
+      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
     xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
     xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
@@ -66,14 +65,14 @@ class Attention:
       self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
       if isinstance(x.device, tuple):
         # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-        self.cache_kv.shard_((x.device), axis=None).realize()
+        self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
 
     # update the cache
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
+    self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
-    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -81,39 +80,26 @@ class Attention:
     attn = attn.reshape(bsz, seqlen, -1)
     return self.wo(attn)
 
-
 class FeedForward:
-  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
+  def __init__(self, dim:int, hidden_dim:int, linear=nn.Linear):
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
-    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit
-
-  def __call__(self, x: Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu() * self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
+    self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
 
+  def __call__(self, x:Tensor) -> Tensor:
+    return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
 
 class TransformerBlock:
-  def __init__(
-    self,
-    dim: int,
-    hidden_dim: int,
-    n_heads: int,
-    n_kv_heads: int,
-    norm_eps: float,
-    max_context: int,
-    linear=nn.Linear,
-    feed_forward=FeedForward,
-  ):
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
+  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
-
 # standard openai sampling
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
@@ -121,8 +107,7 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
 
   # if temperature is very low just use argmax
-  if temp < 1e-6:
-    return logits.argmax()
+  if temp < 1e-6: return logits.argmax()
 
   # alpha sampling
   if af or ap:
@@ -136,16 +121,10 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   # softmax
   t = (logits / temp).softmax()
 
-  counter, counter2 = (
-    Tensor.arange(t.numel(), device=logits.device).contiguous(),
-    Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous(),
-  )
+  counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
   # top k
   if k:
-    output, output_indices = (
-      Tensor.zeros(k, device=logits.device).contiguous(),
-      Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous(),
-    )
+    output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
     for i in range(k):
       t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
       output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
@@ -170,84 +149,48 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 
   return output_token
 
+from exo.inference.shard import Shard
 
 class Transformer:
-  def __init__(
-    self,
-    dim: int,
-    hidden_dim: int,
-    n_heads: int,
-    n_layers: int,
-    norm_eps: float,
-    vocab_size,
-    shard: Shard,
-    linear=nn.Linear,
-    n_kv_heads=None,
-    rope_theta=10000,
-    max_context=1024,
-    jit=True,
-    feed_forward=FeedForward,
-  ):
-    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(shard.end_layer - shard.start_layer + 1)]
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, shard: Shard=None, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
+    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous()
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
 
-  def forward(
-    self,
-    h: Tensor,
-    start_pos: Union[Variable, int],
-    temperature: float,
-    top_k: int,
-    top_p: float,
-    alpha_f: float,
-    alpha_p: float,
-  ):
-    seqlen = h.shape[1]
-    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
+  def forward(self, x:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
+    seqlen = x.shape[1]
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
+    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos+1).realize() if seqlen > 1 else None
 
     if self.shard.is_first_layer():
-      h = self.tok_embeddings(h)
-    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos + 1).realize() if seqlen > 1 else None
+      h = self.tok_embeddings(x)
+    else:
+      h = x
 
-    for layer in self.layers:
+    for i in range(self.shard.start_layer, self.shard.end_layer + 1):
+      layer = self.layers[i]
       h = layer(h, start_pos, freqs_cis, mask)
 
     if self.shard.is_last_layer():
       logits = self.output(self.norm(h)).float()[:, -1, :]
       return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
     else:
-      return h.realize()
-
-  def __call__(
-    self,
-    tokens: Tensor,
-    start_pos: Variable,
-    temperature: float = 0.0,
-    top_k: int = 0,
-    top_p: float = 0.8,
-    alpha_f: float = 0.0,
-    alpha_p: float = 0.0,
-  ):
+      return h
+
+  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
-    # if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
-    #   return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
+    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
+      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
     return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
 
-  def reset(self):
-    for layer in self.layers:
-      if hasattr(layer.attention, "cache_kv"):
-        layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
-
-
 # *** helpers ***
 
-
-def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard):
+def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 
@@ -255,30 +198,16 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
     "model.embed_tokens.weight": "tok_embeddings.weight",
     **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
     **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.biases": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.scales": f"layers.{l}.attention.w{x}.scale" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
     **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.post_attention_layernorm.biases": f"layers.{l}.ffn_norm.bias" for l in range(len(model.layers))},
     **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.biases": f"layers.{l}.feed_forward.w{y}.bias" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.scales": f"layers.{l}.feed_forward.w{y}.scale" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
     "model.norm.weight": "norm.weight",
     "lm_head.weight": "output.weight",
-    "lm_head.biases": "output.bias",
-    "lm_head.scales": "output.scale",
   }
   sd = {}
   for k, v in weights.items():
-    if ".rotary_emb." in k:
-      continue
+    if ".rotary_emb." in k: continue
     v = v.to(Device.DEFAULT)
     if "model.layers" in k:
-      layer_num = int(k.split(".")[2])
-      if shard.start_layer <= layer_num <= shard.end_layer:
-        k = f"model.layers.{layer_num - shard.start_layer}." + ".".join(k.split(".")[3:])
-      else:
-        continue
-
       if "q_proj" in k:
         v = permute(v, n_heads)
       elif "k_proj" in k:
@@ -286,10 +215,9 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
     sd[keymap[k]] = v
   return sd
 
-
-def fix_bf16(weights: Dict[Any, Tensor]):
+def fix_bf16(weights:Dict[Any, Tensor]):
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
+    return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
   # TODO: check if device supports bf16
-  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
+  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}

+ 38 - 0
exo/inference/tinygrad/tinygrad_helpers.py

@@ -0,0 +1,38 @@
+from tinygrad.nn.state import safe_load, torch_load
+from tinygrad import Tensor
+from pathlib import Path
+import json
+from typing import List
+from exo.inference.shard import Shard
+from exo.helpers import DEBUG
+
+# **** helper functions ****
+def concat_weights(models, device=None):
+  def convert(name) -> Tensor:
+    disk_tensors: List[Tensor] = [model[name] for model in models]
+    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
+      return disk_tensors[0].to(device=device)
+    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
+    lazy_tensors = [data.to(device=device) for data in disk_tensors]
+    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
+  return {name: convert(name) for name in {name: None for model in models for name in model}}
+
+def load(fn:str, shard: Shard):
+  if fn.endswith('.index.json'):
+    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
+    parts = {}
+    filtered_weight_map = {}
+    for k, n in weight_map.items():
+      if k.startswith("model.layers."):
+        layer_num = int(k.split('.')[2])
+        if layer_num < shard.start_layer or layer_num > shard.end_layer:
+          continue
+
+      parts[n] = load(str(Path(fn).parent / Path(n).name), shard)
+      filtered_weight_map[k] = n
+    if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {set(weight_map.keys()) - set(filtered_weight_map.keys())}")
+    return {k: parts[n][k] for k, n in filtered_weight_map.items()}
+  elif fn.endswith(".safetensors"):
+    return safe_load(fn)
+  else:
+    return torch_load(fn)

+ 6 - 6
exo/networking/grpc/grpc_server.py

@@ -48,7 +48,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     image_str = request.image_str
     request_id = request.request_id
     result = await self.node.process_prompt(shard, prompt, image_str, request_id)
-    if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
+    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
@@ -64,14 +64,14 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     inference_state = request.inference_state
 
     result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
-    if DEBUG >= 2: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
+    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
   async def GetInferenceResult(self, request, context):
     request_id = request.request_id
     result = await self.node.get_inference_result(request_id)
-    if DEBUG >= 2: print(f"GetInferenceResult {request_id=}: {result}")
+    if DEBUG >= 5: print(f"GetInferenceResult {request_id=}: {result}")
     tensor_data = result[0].tobytes() if result[0] is not None else None
     return (
       node_service_pb2.InferenceResult(
@@ -96,20 +96,20 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       for node_id, cap in topology.nodes.items()
     }
     peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
-    if DEBUG >= 2: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
+    if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
     return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
 
   async def SendResult(self, request, context):
     request_id = request.request_id
     result = request.result
     is_finished = request.is_finished
-    if DEBUG >= 2: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
+    if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
     self.node.on_token.trigger_all(request_id, result, is_finished)
     return node_service_pb2.Empty()
 
   async def SendOpaqueStatus(self, request, context):
     request_id = request.request_id
     status = request.status
-    if DEBUG >= 2: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
+    if DEBUG >= 5: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
     self.node.on_opaque_status.trigger_all(request_id, status)
     return node_service_pb2.Empty()

+ 14 - 21
exo/orchestration/standard_node.py

@@ -3,6 +3,7 @@ import json
 import asyncio
 import uuid
 import time
+import traceback
 from typing import List, Dict, Optional, Tuple, Union
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
@@ -13,6 +14,7 @@ from exo.topology.partitioning_strategy import Partition, PartitioningStrategy,
 from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
+from exo.download.hf.hf_helpers import RepoProgressEvent
 
 
 class StandardNode(Node):
@@ -23,7 +25,7 @@ class StandardNode(Node):
     inference_engine: InferenceEngine,
     discovery: Discovery,
     partitioning_strategy: PartitioningStrategy = None,
-    max_generate_tokens: int = 256,
+    max_generate_tokens: int = 1024,
     chatgpt_api_endpoint: Optional[str] = None,
     web_chat_url: Optional[str] = None,
     disable_tui: Optional[bool] = False,
@@ -42,6 +44,7 @@ class StandardNode(Node):
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
+    self.node_download_progress: Dict[str, RepoProgressEvent] = {}
 
   def on_node_status(self, request_id, opaque_status):
     try:
@@ -54,13 +57,14 @@ class StandardNode(Node):
             self.current_topology.active_node_id = None
       download_progress = None
       if status_data.get("type", "") == "download_progress":
-        if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('current')}/{status_data.get('total')} ({round(status_data.get('current') / status_data.get('total') * 100, 2)}%)")
-        if status_data.get("node_id") == self.id:
-          download_progress = (status_data.get('current'), status_data.get('total'))
+        if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
+        download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
+        self.node_download_progress[status_data.get('node_id')] = download_progress
       if self.topology_viz:
-        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), download_progress)
-    except json.JSONDecodeError:
-      pass
+        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id, self.node_download_progress)
+    except Exception as e:
+      if DEBUG >= 1: print(f"Error updating visualization: {e}")
+      if DEBUG >= 1: traceback.print_exc()
 
   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
@@ -231,8 +235,6 @@ class StandardNode(Node):
       return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
-      import traceback
-
       traceback.print_exc()
       return None
 
@@ -287,15 +289,14 @@ class StandardNode(Node):
 
   async def update_peers(self, wait_for_peers: int = 0) -> None:
     self.peers = await self.discovery.discover_peers(wait_for_peers)
-    if DEBUG >= 2: print(f"Starting with the following peers: {self.peers}")
-    if DEBUG >= 2: print("Connecting to new peers...")
     for peer in self.peers:
       is_connected = await peer.is_connected()
       if DEBUG >= 2 and is_connected:
         print(f"Already connected to {peer.id()}: {is_connected}")
       if not is_connected:
+        if DEBUG >= 2: print(f"Connecting to {peer.id()}...")
         await peer.connect()
-        if DEBUG >= 0: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})")
+        if DEBUG >= 1: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})")
 
   async def periodic_topology_collection(self, interval: int):
     while True:
@@ -306,9 +307,6 @@ class StandardNode(Node):
       except Exception as e:
         print(f"Error collecting topology: {e}")
 
-      if DEBUG >= 2: print("Topology collection task executed.")
-      if DEBUG >= 2: print(f"Current topology: {self.topology}")
-
   async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
     if request_id not in self.buffered_token_output:
       return None, False
@@ -328,7 +326,6 @@ class StandardNode(Node):
       next_topology.add_edge(self.id, peer.id())
 
       if peer.id() in prev_visited:
-        if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
         continue
 
       if max_depth <= 0:
@@ -345,7 +342,7 @@ class StandardNode(Node):
     next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
     self.topology = next_topology
     if self.topology_viz:
-      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
+      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id)
     return next_topology
 
   @property
@@ -368,8 +365,6 @@ class StandardNode(Node):
         print(f"Timeout broadcasting result to {peer.id()}")
       except Exception as e:
         print(f"Error broadcasting result to {peer.id()}: {e}")
-        import traceback
-
         traceback.print_exc()
 
     await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
@@ -383,8 +378,6 @@ class StandardNode(Node):
         print(f"Timeout sending opaque status to {peer.id()}")
       except Exception as e:
         print(f"Error sending opaque status to {peer.id()}: {e}")
-        import traceback
-
         traceback.print_exc()
 
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)

+ 67 - 1
exo/viz/test_topology_viz.py

@@ -1,9 +1,63 @@
 import asyncio
 import unittest
+from datetime import timedelta
 from exo.viz.topology_viz import TopologyViz
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.partitioning_strategy import Partition
+from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
+
+
+def create_hf_repo_progress_event(
+    completed_files: int = 5,
+    total_files: int = 10,
+    downloaded_bytes: int = 500000000,
+    downloaded_bytes_this_session: int = 250000000,
+    total_bytes: int = 1000000000,
+    overall_speed: int = 5000000,
+    overall_eta: timedelta = timedelta(seconds=100),
+    file_progress: dict = None,
+    status: str = "in_progress"
+) -> RepoProgressEvent:
+    if file_progress is None:
+        file_progress = {
+            "file1.bin": RepoFileProgressEvent(
+                repo_id="repo_id",
+                repo_revision="repo_revision",
+                file_path="file1.bin",
+                downloaded=100000000,
+                downloaded_this_session=50000000,
+                total=200000000,
+                speed=1000000,
+                eta=timedelta(seconds=100),
+                status="in_progress"
+            ),
+            "file2.bin": RepoFileProgressEvent(
+                repo_id="repo_id",
+                repo_revision="repo_revision",
+                file_path="file2.bin",
+                downloaded=200000000,
+                downloaded_this_session=100000000,
+                total=200000000,
+                speed=2000000,
+                eta=timedelta(seconds=0),
+                status="complete"
+            )
+        }
+
+    return RepoProgressEvent(
+        repo_id="repo_id",
+        repo_revision="repo_revision",
+        completed_files=completed_files,
+        total_files=total_files,
+        downloaded_bytes=downloaded_bytes,
+        downloaded_bytes_this_session=downloaded_bytes_this_session,
+        total_bytes=total_bytes,
+        overall_speed=overall_speed,
+        overall_eta=overall_eta,
+        file_progress=file_progress,
+        status=status
+    )
 
 
 class TestNodeViz(unittest.IsolatedAsyncioTestCase):
@@ -30,7 +84,7 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
     await asyncio.sleep(2)  # Simulate running for a short time
 
   async def test_layout_generation(self):
-    self.top_viz._generate_layout()
+    # self.top_viz._generate_layout()
     self.top_viz.refresh()
     import time
 
@@ -43,6 +97,13 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
         Partition("node2", 0.4, 0.8),
         Partition("node3", 0.8, 0.9),
       ],
+      "node1",
+      {
+        "node1": create_hf_repo_progress_event(),
+        "node2": create_hf_repo_progress_event(),
+        "node3": create_hf_repo_progress_event(),
+        "node4": create_hf_repo_progress_event(),
+      },
     )
     time.sleep(2)
     self.topology.active_node_id = "node3"
@@ -54,6 +115,11 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
         Partition("node2", 0.5, 0.7),
         Partition("node4", 0.7, 0.9),
       ],
+      "node5",
+      {
+        "node1": create_hf_repo_progress_event(),
+        "node5": create_hf_repo_progress_event(),
+      },
     )
     time.sleep(2)
 

+ 110 - 38
exo/viz/topology_viz.py

@@ -1,51 +1,72 @@
 import math
-from typing import List, Optional, Tuple
-from exo.helpers import exo_text
+from typing import List, Optional, Tuple, Dict
+from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
+from exo.download.hf.hf_helpers import RepoProgressEvent
 from rich.console import Console
 from rich.panel import Panel
 from rich.text import Text
 from rich.live import Live
 from rich.style import Style
+from rich.table import Table
+from rich.layout import Layout
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
 
-
 class TopologyViz:
   def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
     self.chatgpt_api_endpoint = chatgpt_api_endpoint
     self.web_chat_url = web_chat_url
     self.topology = Topology()
     self.partitions: List[Partition] = []
-    self.download_progress = None
+    self.node_id = None
+    self.node_download_progress: Dict[str, RepoProgressEvent] = {}
 
     self.console = Console()
-    self.panel = Panel(self._generate_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
-    self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
+    self.layout = Layout()
+    self.layout.split(
+      Layout(name="main"),
+      Layout(name="download", size=15)
+    )
+    self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.download_panel = Panel("", title="Download Progress", border_style="cyan")
+    self.layout["main"].update(self.main_panel)
+    self.layout["download"].update(self.download_panel)
+    self.live_panel = Live(self.layout, auto_refresh=False, console=self.console)
     self.live_panel.start()
 
-  def update_visualization(self, topology: Topology, partitions: List[Partition], download_progress: Optional[Tuple[int, int]] = None):
+  def update_visualization(self, topology: Topology, partitions: List[Partition], node_id: Optional[str] = None, node_download_progress: Dict[str, RepoProgressEvent] = {}):
     self.topology = topology
     self.partitions = partitions
-    self.download_progress = download_progress
+    self.node_id = node_id
+    if node_download_progress:
+      self.node_download_progress = node_download_progress
     self.refresh()
 
   def refresh(self):
-    self.panel.renderable = self._generate_layout()
+    self.main_panel.renderable = self._generate_main_layout()
     # Update the panel title with the number of nodes and partitions
     node_count = len(self.topology.nodes)
-    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''}){f' {self.download_progress[0]/self.download_progress[1]:.2%} Downloaded' if self.download_progress else ''}"
-    self.live_panel.update(self.panel, refresh=True)
+    self.main_panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
+
+    # Only show download_panel if there are in-progress downloads
+    if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
+      self.download_panel.renderable = self._generate_download_layout()
+      self.layout["download"].visible = True
+    else:
+      self.layout["download"].visible = False
+
+    self.live_panel.update(self.layout, refresh=True)
 
-  def _generate_layout(self) -> str:
+  def _generate_main_layout(self) -> str:
     # Calculate visualization parameters
     num_partitions = len(self.partitions)
-    radius_x = 30  # Increased horizontal radius
-    radius_y = 12  # Decreased vertical radius
-    center_x, center_y = 50, 28  # Centered horizontally and moved up slightly
+    radius_x = 30
+    radius_y = 12
+    center_x, center_y = 50, 24  # Increased center_y to add more space
 
     # Generate visualization
-    visualization = [[" " for _ in range(100)] for _ in range(55)]  # Decreased height
+    visualization = [[" " for _ in range(100)] for _ in range(48)]  # Increased height to 48
 
     # Add exo_text at the top in bright yellow
     exo_lines = exo_text.split("\n")
@@ -53,7 +74,7 @@ class TopologyViz:
     max_line_length = max(len(line) for line in exo_lines)
     for i, line in enumerate(exo_lines):
       centered_line = line.center(max_line_length)
-      start_x = (100 - max_line_length) // 2 + 15  # Center the text plus empirical adjustment of 15
+      start_x = (100 - max_line_length) // 2 + 15
       colored_line = Text(centered_line, style=yellow_style)
       for j, char in enumerate(str(colored_line)):
         if 0 <= start_x + j < 100 and i < len(visualization):
@@ -68,9 +89,9 @@ class TopologyViz:
 
     info_start_y = len(exo_lines) + 1
     for i, line in enumerate(info_lines):
-      start_x = (100 - len(line)) // 2 + 15  # Center the info lines plus empirical adjustment of 15
+      start_x = (100 - len(line)) // 2 + 15
       for j, char in enumerate(line):
-        if 0 <= start_x + j < 100 and info_start_y + i < 55:
+        if 0 <= start_x + j < 100 and info_start_y + i < 48:
           visualization[info_start_y + i][start_x + j] = char
 
     # Calculate total FLOPS and position on the bar
@@ -78,13 +99,13 @@ class TopologyViz:
     bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
 
     # Add GPU poor/rich bar
-    bar_width = 30  # Increased bar width
-    bar_start_x = (100 - bar_width) // 2  # Center the bar
-    bar_y = info_start_y + len(info_lines) + 1  # Position the bar below the info section with two cells of space
+    bar_width = 30
+    bar_start_x = (100 - bar_width) // 2
+    bar_y = info_start_y + len(info_lines) + 1
 
     # Create a gradient bar using emojis
     gradient_bar = Text()
-    emojis = ["🟥", "🟧", "🟨", "🟩"]  # Red, Orange, Yellow, Green
+    emojis = ["🟥", "🟧", "🟨", "🟩"]
     for i in range(bar_width):
       emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
       gradient_bar.append(emojis[emoji_index])
@@ -106,6 +127,9 @@ class TopologyViz:
     visualization[bar_y + 1][pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
     visualization[bar_y + 2][pos_x] = "▲"
 
+    # Add an extra empty line for spacing
+    bar_y += 4
+
     for i, partition in enumerate(self.partitions):
       device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
 
@@ -113,11 +137,13 @@ class TopologyViz:
       x = int(center_x + radius_x * math.cos(angle))
       y = int(center_y + radius_y * math.sin(angle))
 
-      # Place node with different color for active node
+      # Place node with different color for active node and this node
       if partition.node_id == self.topology.active_node_id:
-        visualization[y][x] = "🔴"  # Red circle for active node
+        visualization[y][x] = "🔴"
+      elif partition.node_id == self.node_id:
+        visualization[y][x] = "🟢"
       else:
-        visualization[y][x] = "🔵"  # Blue circle for inactive nodes
+        visualization[y][x] = "🔵"
 
       # Place node info (model, memory, TFLOPS, partition) on three lines
       node_info = [
@@ -127,28 +153,27 @@ class TopologyViz:
       ]
 
       # Calculate info position based on angle
-      info_distance_x = radius_x + 6  # Increased horizontal distance
-      info_distance_y = radius_y + 3  # Decreased vertical distance
+      info_distance_x = radius_x + 6
+      info_distance_y = radius_y + 3
       info_x = int(center_x + info_distance_x * math.cos(angle))
       info_y = int(center_y + info_distance_y * math.sin(angle))
 
       # Adjust text position to avoid overwriting the node icon and prevent cutoff
-      if info_x < x:  # Text is to the left of the node
+      if info_x < x:
         info_x = max(0, x - len(max(node_info, key=len)) - 1)
-      elif info_x > x:  # Text is to the right of the node
+      elif info_x > x:
         info_x = min(99 - len(max(node_info, key=len)), info_x)
 
       # Adjust for top and bottom nodes
-      if 5 * math.pi / 4 < angle < 7 * math.pi / 4:  # Node is near the top
-        info_x += 4  # Shift text slightly to the right
-      elif math.pi / 4 < angle < 3 * math.pi / 4:  # Node is near the bottom
-        info_x += 3  # Shift text slightly to the right
-        info_y -= 2  # Move text up by two cells
+      if 5 * math.pi / 4 < angle < 7 * math.pi / 4:
+        info_x += 4
+      elif math.pi / 4 < angle < 3 * math.pi / 4:
+        info_x += 3
+        info_y -= 2
 
       for j, line in enumerate(node_info):
         for k, char in enumerate(line):
-          if 0 <= info_y + j < 55 and 0 <= info_x + k < 100:  # Updated height check
-            # Ensure we're not overwriting the node icon
+          if 0 <= info_y + j < 48 and 0 <= info_x + k < 100:
             if info_y + j != y or info_x + k != x:
               visualization[info_y + j][info_x + k] = char
 
@@ -163,8 +188,55 @@ class TopologyViz:
       for step in range(1, steps):
         line_x = int(x + (next_x - x) * step / steps)
         line_y = int(y + (next_y - y) * step / steps)
-        if 0 <= line_y < 55 and 0 <= line_x < 100:  # Updated height check
+        if 0 <= line_y < 48 and 0 <= line_x < 100:
           visualization[line_y][line_x] = "-"
 
     # Convert to string
     return "\n".join("".join(str(char) for char in row) for row in visualization)
+
+  def _generate_download_layout(self) -> Table:
+    summary = Table(show_header=False, box=None, padding=(0, 1), expand=True)
+    summary.add_column("Info", style="cyan", no_wrap=True, ratio=50)
+    summary.add_column("Progress", style="cyan", no_wrap=True, ratio=40)
+    summary.add_column("Percentage", style="cyan", no_wrap=True, ratio=10)
+
+    # Current node download progress
+    if self.node_id in self.node_download_progress:
+        download_progress = self.node_download_progress[self.node_id]
+        title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):"
+        summary.add_row(Text(title, style="bold"))
+        progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
+        summary.add_row(progress_info)
+
+        eta_info = f"{download_progress.overall_eta}"
+        summary.add_row(eta_info)
+
+        summary.add_row("")  # Empty row for spacing
+
+        for file_path, file_progress in download_progress.file_progress.items():
+            if file_progress.status != "complete":
+                progress = int(file_progress.downloaded / file_progress.total * 30)
+                bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
+                percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
+                summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
+
+    summary.add_row("")  # Empty row for spacing
+
+    # Other nodes download progress summary
+    summary.add_row(Text("Other Nodes Download Progress:", style="bold"))
+    for node_id, progress in self.node_download_progress.items():
+        if node_id != self.node_id:
+            device = self.topology.nodes.get(node_id)
+            partition = next((p for p in self.partitions if p.node_id == node_id), None)
+            partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
+            percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
+            speed = pretty_print_bytes_per_second(progress.overall_speed)
+            device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
+            progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
+            progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]"
+            percentage_str = f"{percentage:.1f}%"
+            eta_str = f"{progress.overall_eta}"
+            summary.add_row(device_info, progress_info, percentage_str)
+            summary.add_row("", progress_bar, eta_str)
+
+    return summary

+ 53 - 0
extra/download_hf.py

@@ -0,0 +1,53 @@
+import argparse
+import asyncio
+from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent
+
+DEFAULT_ALLOW_PATTERNS = [
+    "*.json",
+    "*.py",
+    "tokenizer.model",
+    "*.tiktoken",
+    "*.txt",
+    "*.safetensors",
+]
+# Always ignore `.git` and `.cache/huggingface` folders in commits
+DEFAULT_IGNORE_PATTERNS = [
+    ".git",
+    ".git/*",
+    "*/.git",
+    "**/.git/**",
+    ".cache/huggingface",
+    ".cache/huggingface/*",
+    "*/.cache/huggingface",
+    "**/.cache/huggingface/**",
+]
+
+async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
+    async def progress_callback(event: RepoProgressEvent):
+        print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
+        print(f"Estimated time remaining: {event.overall_eta}")
+        print("File Progress:")
+        for file_path, progress in event.file_progress.items():
+            status_icon = {
+                'not_started': '⚪',
+                'in_progress': '🔵',
+                'complete': '✅'
+            }[progress.status]
+            eta_str = str(progress.eta)
+            print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
+                  f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
+        print("\n")
+
+    await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
+    parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
+    parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
+    parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
+    parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
+
+    args = parser.parse_args()
+
+    asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))

+ 30 - 5
main.py

@@ -2,13 +2,17 @@ import argparse
 import asyncio
 import signal
 import json
-import uuid
+import time
+import traceback
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
+from exo.download.shard_download import ShardDownloader, RepoProgressEvent
+from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id
+from exo.inference.shard import Shard
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -22,7 +26,7 @@ parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
-parser.add_argument("--max-generate-tokens", type=int, default=256, help="Max tokens to generate in each request")
+parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 args = parser.parse_args()
@@ -32,9 +36,10 @@ print_yellow_exo()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
+shard_downloader: ShardDownloader = HFShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
-inference_engine = get_inference_engine(inference_engine_name)
-print(f"Using inference engine: {inference_engine.__class__.__name__}")
+inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
+print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 
 if args.node_port is None:
     args.node_port = find_available_port(args.node_host)
@@ -57,10 +62,30 @@ server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
 node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+def preemptively_start_download(request_id: str, opaque_status: str):
+    try:
+        status = json.loads(opaque_status)
+        if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
+            current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
+            if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
+            asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+    except Exception as e:
+        if DEBUG >= 2:
+            print(f"Failed to preemptively start download: {e}")
+            traceback.print_exc()
+node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)
-inference_engine.set_on_download_progress(lambda current, total: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": current, "total": total}))))
+
+last_broadcast_time = 0
+def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
+    global last_broadcast_time
+    current_time = time.time()
+    if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
+        last_broadcast_time = current_time
+        asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""

+ 3 - 1
setup.py

@@ -6,6 +6,7 @@ from setuptools import find_packages, setup
 install_requires = [
     "aiohttp==3.9.5",
     "aiohttp_cors==0.7.0",
+    "aiofiles==24.1.0",
     "blobfile==2.1.1",
     "grpcio==1.64.1",
     "grpcio-tools==1.64.1",
@@ -21,6 +22,7 @@ install_requires = [
     "requests==2.32.3",
     "rich==13.7.1",
     "safetensors==0.4.3",
+    "tenacity==9.0.0",
     "tiktoken==0.7.0",
     "tokenizers==0.19.1",
     "tqdm==4.66.4",
@@ -33,7 +35,7 @@ install_requires = [
 if sys.platform.startswith("darwin"):
     install_requires.extend(
         [
-            "mlx==0.16.0",
+            "mlx==0.16.1",
             "mlx-lm==0.16.1",
         ]
     )