
Merge pull request #124 from exo-explore/refactor_model_download

Refactor model download, refactor tinygrad
Alex Cheema, 9 months ago
parent commit 9e78c42b4b

+ 4 - 4
.circleci/config.yml

@@ -17,11 +17,11 @@ commands:
            source env/bin/activate

            # Start first instance
-            DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout-secs 900 > output1.log 2>&1 &
+            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout-secs 900 > output1.log 2>&1 &
            PID1=$!

            # Start second instance
-            DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout-secs 900 > output2.log 2>&1 &
+            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout-secs 900 > output2.log 2>&1 &
            PID2=$!

            # Wait for discovery
@@ -132,9 +132,9 @@ jobs:
          name: Run discovery integration test
          command: |
            source env/bin/activate
-            DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 > output1.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 > output1.log 2>&1 &
            PID1=$!
-            DEBUG_DISCOVERY=9 DEBUG=9 python3 main.py --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 > output2.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 > output2.log 2>&1 &
            PID2=$!
            sleep 10
            kill $PID1 $PID2

+ 18 - 38
exo/api/chatgpt_api.py

@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoProcessor
 from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
+import traceback
 from exo import DEBUG, VERSION
 from exo.helpers import terminal_link, PrefixDict
 from exo.inference.shard import Shard
@@ -16,20 +17,22 @@ shard_mappings = {
   ### llama
   "llama-3.1-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3.1-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
   },
   "llama-3.1-405b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
   },
   "llama-3-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
   },
   ### mistral
   "mistral-nemo": {
@@ -76,18 +79,10 @@ class ChatCompletionRequest:
         }


-def resolve_tinygrad_tokenizer(model_id: str):
-  if model_id == "llama3-8b-sfr":
-    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
-  elif model_id == "llama3-70b-sfr":
-    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
-  else:
-    raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")
-

 async def resolve_tokenizer(model_id: str):
   try:
-    if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}")
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
     processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
     if not hasattr(processor, 'eos_token_id'):
       processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
@@ -97,33 +92,17 @@ async def resolve_tokenizer(model_id: str):
       processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
-    if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}")
-    import traceback
-
-    if DEBUG >= 2: print(traceback.format_exc())
+    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())

   try:
-    if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
     return AutoTokenizer.from_pretrained(model_id)
   except Exception as e:
-    if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
-    import traceback
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())

-    if DEBUG >= 2: print(traceback.format_exc())
-
-  try:
-    if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
-    return resolve_tinygrad_tokenizer(model_id)
-  except Exception as e:
-    if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer. Error: {e}")
-    import traceback
-
-    if DEBUG >= 2: print(traceback.format_exc())
-
-  if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
-  from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
-
-  return load_tokenizer(await get_model_path(model_id))
+  raise ValueError(f"[TODO] Unsupported model: {model_id}")


 def generate_completion(
@@ -326,10 +305,7 @@ class ChatGPTAPI:
     try:
       await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
     except Exception as e:
-      if DEBUG >= 2:
-        import traceback
-
-        traceback.print_exc()
+      if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

     try:
@@ -370,7 +346,11 @@ class ChatGPTAPI:
             "chat.completion",
           )
           if DEBUG >= 2: print(f"Streaming completion: {completion}")
-          await response.write(f"data: {json.dumps(completion)}\n\n".encode())
+          try:
+            await response.write(f"data: {json.dumps(completion)}\n\n".encode())
+          except Exception as e:
+            if DEBUG >= 2: print(f"Error streaming completion: {e}")
+            if DEBUG >= 2: traceback.print_exc()

         def on_result(_request_id: str, tokens: List[int], is_finished: bool):
           self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))

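A minimal sketch (not part of this diff) of exercising the simplified `resolve_tokenizer` path directly; the model id is just one of the entries from `shard_mappings` above:

```python
import asyncio
from exo.api.chatgpt_api import resolve_tokenizer

# resolve_tokenizer now tries AutoProcessor, then AutoTokenizer, and raises a
# ValueError instead of falling back to hard-coded tinygrad tokenizers.
tokenizer = asyncio.run(resolve_tokenizer("mlx-community/Meta-Llama-3-8B-Instruct-4bit"))
print(type(tokenizer).__name__, tokenizer.eos_token_id)
```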
+ 82 - 0
exo/download/download_progress.py

@@ -0,0 +1,82 @@
+from typing import Dict, Callable, Coroutine, Any, Literal
+from dataclasses import dataclass
+from datetime import timedelta
+
+@dataclass
+class RepoFileProgressEvent:
+    repo_id: str
+    repo_revision: str
+    file_path: str
+    downloaded: int
+    downloaded_this_session: int
+    total: int
+    speed: int
+    eta: timedelta
+    status: Literal["not_started", "in_progress", "complete"]
+
+    def to_dict(self):
+        return {
+            "repo_id": self.repo_id,
+            "repo_revision": self.repo_revision,
+            "file_path": self.file_path,
+            "downloaded": self.downloaded,
+            "downloaded_this_session": self.downloaded_this_session,
+            "total": self.total,
+            "speed": self.speed,
+            "eta": self.eta.total_seconds(),
+            "status": self.status
+        }
+
+    @classmethod
+    def from_dict(cls, data):
+        # Convert eta from seconds back to timedelta
+        if 'eta' in data:
+            data['eta'] = timedelta(seconds=data['eta'])
+        return cls(**data)
+
+@dataclass
+class RepoProgressEvent:
+    repo_id: str
+    repo_revision: str
+    completed_files: int
+    total_files: int
+    downloaded_bytes: int
+    downloaded_bytes_this_session: int
+    total_bytes: int
+    overall_speed: int
+    overall_eta: timedelta
+    file_progress: Dict[str, RepoFileProgressEvent]
+    status: Literal["not_started", "in_progress", "complete"]
+
+    def to_dict(self):
+        return {
+            "repo_id": self.repo_id,
+            "repo_revision": self.repo_revision,
+            "completed_files": self.completed_files,
+            "total_files": self.total_files,
+            "downloaded_bytes": self.downloaded_bytes,
+            "downloaded_bytes_this_session": self.downloaded_bytes_this_session,
+            "total_bytes": self.total_bytes,
+            "overall_speed": self.overall_speed,
+            "overall_eta": self.overall_eta.total_seconds(),
+            "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
+            "status": self.status
+        }
+
+    @classmethod
+    def from_dict(cls, data):
+        # Convert overall_eta from seconds back to timedelta
+        if 'overall_eta' in data:
+            data['overall_eta'] = timedelta(seconds=data['overall_eta'])
+
+        # Parse file_progress
+        if 'file_progress' in data:
+            data['file_progress'] = {
+                k: RepoFileProgressEvent.from_dict(v)
+                for k, v in data['file_progress'].items()
+            }
+
+        return cls(**data)
+
+RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
+RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]
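A small usage sketch (not part of this diff) showing how these events round-trip through `to_dict`/`from_dict`, e.g. when progress is forwarded between nodes; the field values below are made up:

```python
from datetime import timedelta
from exo.download.download_progress import RepoFileProgressEvent

# Build a file-level progress event and round-trip it through a plain dict,
# which is how it would cross a JSON/serialization boundary between nodes.
event = RepoFileProgressEvent(
    repo_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
    repo_revision="main",
    file_path="model.safetensors",
    downloaded=1024,
    downloaded_this_session=1024,
    total=4096,
    speed=512,
    eta=timedelta(seconds=6),
    status="in_progress",
)

as_dict = event.to_dict()                            # eta is serialized as seconds
restored = RepoFileProgressEvent.from_dict(as_dict)  # seconds converted back to timedelta
assert restored == event
```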

+ 305 - 0
exo/download/hf/hf_helpers.py

@@ -0,0 +1,305 @@
+import asyncio
+import aiohttp
+import json
+import os
+from urllib.parse import urljoin
+from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
+from datetime import datetime, timedelta
+from fnmatch import fnmatch
+from pathlib import Path
+from typing import Generator, Iterable, TypeVar, TypedDict
+from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
+from exo.helpers import DEBUG
+from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
+from exo.inference.shard import Shard
+
+T = TypeVar("T")
+def filter_repo_objects(
+    items: Iterable[T],
+    *,
+    allow_patterns: Optional[Union[List[str], str]] = None,
+    ignore_patterns: Optional[Union[List[str], str]] = None,
+    key: Optional[Callable[[T], str]] = None,
+) -> Generator[T, None, None]:
+    if isinstance(allow_patterns, str):
+        allow_patterns = [allow_patterns]
+    if isinstance(ignore_patterns, str):
+        ignore_patterns = [ignore_patterns]
+    if allow_patterns is not None:
+        allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
+    if ignore_patterns is not None:
+        ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
+
+    if key is None:
+        def _identity(item: T) -> str:
+            if isinstance(item, str):
+                return item
+            if isinstance(item, Path):
+                return str(item)
+            raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
+        key = _identity
+
+    for item in items:
+        path = key(item)
+        if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
+            continue
+        if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
+            continue
+        yield item
+
+def _add_wildcard_to_directories(pattern: str) -> str:
+    if pattern[-1] == "/":
+        return pattern + "*"
+    return pattern
+
+def get_hf_home() -> Path:
+    """Get the Hugging Face home directory."""
+    return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
+
+def get_hf_token():
+    """Retrieve the Hugging Face token from the user's HF_HOME directory."""
+    token_path = get_hf_home() / "token"
+    if token_path.exists():
+        return token_path.read_text().strip()
+    return None
+
+def get_auth_headers():
+    """Get authentication headers if a token is available."""
+    token = get_hf_token()
+    if token:
+        return {"Authorization": f"Bearer {token}"}
+    return {}
+
+def get_repo_root(repo_id: str) -> Path:
+    """Get the root directory for a given repo ID in the Hugging Face cache."""
+    sanitized_repo_id = repo_id.replace("/", "--")
+    return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
+
+async def fetch_file_list(session, repo_id, revision, path=""):
+    api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
+    url = f"{api_url}/{path}" if path else api_url
+
+    headers = get_auth_headers()
+    async with session.get(url, headers=headers) as response:
+        if response.status == 200:
+            data = await response.json()
+            files = []
+            for item in data:
+                if item["type"] == "file":
+                    files.append({"path": item["path"], "size": item["size"]})
+                elif item["type"] == "directory":
+                    subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
+                    files.extend(subfiles)
+            return files
+        else:
+            raise Exception(f"Failed to fetch file list: {response.status}")
+
+
+@retry(
+    stop=stop_after_attempt(5),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
+    reraise=True
+)
+async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True):
+    base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
+    url = urljoin(base_url, file_path)
+    local_path = os.path.join(save_directory, file_path)
+
+    os.makedirs(os.path.dirname(local_path), exist_ok=True)
+
+    # Check if file already exists and get its size
+    local_file_size = os.path.getsize(local_path) if os.path.exists(local_path) else 0
+
+    headers = get_auth_headers()
+    if use_range_request:
+        headers["Range"] = f"bytes={local_file_size}-"
+
+    async with session.get(url, headers=headers) as response:
+        total_size = int(response.headers.get('Content-Length', 0))
+        downloaded_size = local_file_size
+        downloaded_this_session = 0
+        mode = 'ab' if use_range_request else 'wb'
+        if downloaded_size == total_size:
+            if DEBUG >= 2: print(f"File already downloaded: {file_path}")
+            if progress_callback:
+                await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+            return
+
+        if response.status == 200:
+            # File doesn't support range requests or we're not using them, start from beginning
+            mode = 'wb'
+            downloaded_size = 0
+        elif response.status == 206:
+            # Partial content, resume download
+            content_range = response.headers.get('Content-Range', '')
+            try:
+                total_size = int(content_range.split('/')[-1])
+            except ValueError:
+                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
+        elif response.status == 416:
+            # Range not satisfiable, get the actual file size
+            content_range = response.headers.get('Content-Range', '')
+            try:
+                total_size = int(content_range.split('/')[-1])
+                if downloaded_size == total_size:
+                    if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
+                    if progress_callback:
+                        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+                    return
+            except ValueError:
+                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
+        else:
+            raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
+
+        if downloaded_size == total_size:
+            print(f"File already downloaded: {file_path}")
+            if progress_callback:
+                await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+            return
+
+        DOWNLOAD_CHUNK_SIZE = 32768
+        start_time = datetime.now()
+        with open(local_path, mode) as f:
+            async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
+                f.write(chunk)
+                downloaded_size += len(chunk)
+                downloaded_this_session += len(chunk)
+                if progress_callback and total_size:
+                    elapsed_time = (datetime.now() - start_time).total_seconds()
+                    speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
+                    remaining_size = total_size - downloaded_size
+                    eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
+                    status = "in_progress" if downloaded_size < total_size else "complete"
+                    if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
+                    await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
+        if DEBUG >= 2: print(f"Downloaded: {file_path}")
+
+async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None) -> Path:
+    repo_root = get_repo_root(repo_id)
+    refs_dir = repo_root / "refs"
+    snapshots_dir = repo_root / "snapshots"
+
+    # Ensure directories exist
+    refs_dir.mkdir(parents=True, exist_ok=True)
+    snapshots_dir.mkdir(parents=True, exist_ok=True)
+
+    async with aiohttp.ClientSession() as session:
+        # Fetch the commit hash for the given revision
+        api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
+        headers = get_auth_headers()
+        async with session.get(api_url, headers=headers) as response:
+            if response.status != 200:
+                raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
+            revision_info = await response.json()
+            commit_hash = revision_info['sha']
+
+        # Write the commit hash to the refs file
+        refs_file = refs_dir / revision
+        refs_file.write_text(commit_hash)
+
+        # Set up the snapshot directory
+        snapshot_dir = snapshots_dir / commit_hash
+        snapshot_dir.mkdir(exist_ok=True)
+
+        file_list = await fetch_file_list(session, repo_id, revision)
+        filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
+        total_files = len(filtered_file_list)
+        total_bytes = sum(file["size"] for file in filtered_file_list)
+        file_progress: Dict[str, RepoFileProgressEvent] = {file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
+        start_time = datetime.now()
+
+        async def download_with_progress(file_info, progress_state):
+            async def file_progress_callback(event: RepoFileProgressEvent):
+                progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
+                progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
+                file_progress[event.file_path] = event
+                if progress_callback:
+                    elapsed_time = (datetime.now() - start_time).total_seconds()
+                    overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+                    remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+                    overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+                    status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
+                    await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+
+            await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
+            progress_state['completed_files'] += 1
+            file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
+            if progress_callback:
+                elapsed_time = (datetime.now() - start_time).total_seconds()
+                overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+                remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+                overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+                status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+                await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+
+        progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
+        tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
+        await asyncio.gather(*tasks)
+
+    return snapshot_dir
+
+async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
+    """
+    Retrieve the weight map from the model.safetensors.index.json file.
+
+    Args:
+        repo_id (str): The Hugging Face repository ID.
+        revision (str): The revision of the repository to use.
+
+    Returns:
+        Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
+    """
+
+    # Download the index file
+    await download_repo_files(
+        repo_id=repo_id,
+        revision=revision,
+        allow_patterns="model.safetensors.index.json"
+    )
+
+    # Check if the file exists
+    repo_root = get_repo_root(repo_id)
+    snapshot_dir = repo_root / "snapshots"
+    index_file = next(snapshot_dir.glob("*/model.safetensors.index.json"), None)
+
+    if index_file and index_file.exists():
+        with open(index_file, 'r') as f:
+            index_data = json.load(f)
+        return index_data.get("weight_map")
+
+    return None
+
+def extract_layer_num(tensor_name: str) -> Optional[int]:
+    # This is a simple example and might need to be adjusted based on the actual naming convention
+    parts = tensor_name.split('.')
+    for part in parts:
+        if part.isdigit():
+            return int(part)
+    return None
+
+
+def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
+    default_patterns = [
+        "*.json",
+        "*.py",
+        "tokenizer.model",
+        "*.tiktoken",
+        "*.txt",
+    ]
+    shard_specific_patterns = []
+    if weight_map:
+        for tensor_name, filename in weight_map.items():
+            layer_num = extract_layer_num(tensor_name)
+            if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
+                shard_specific_patterns.append(filename)
+        sorted_file_names = sorted(weight_map.values())
+        if shard.is_first_layer():
+            shard_specific_patterns.append(sorted_file_names[0])
+        elif shard.is_last_layer():
+            shard_specific_patterns.append(sorted_file_names[-1])
+    else:
+        shard_specific_patterns = ["*.safetensors"]
+    return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates
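A rough sketch (not part of this diff) of driving `download_repo_files` directly with a progress callback; the repo id and allow patterns are illustrative:

```python
import asyncio
from exo.download.download_progress import RepoProgressEvent
from exo.download.hf.hf_helpers import download_repo_files

async def print_progress(event: RepoProgressEvent):
    # Invoked as chunks arrive, with repo-level aggregates across all files.
    print(f"{event.completed_files}/{event.total_files} files, "
          f"{event.downloaded_bytes}/{event.total_bytes} bytes ({event.status})")

async def main():
    # Downloads matching files into the HF cache layout (refs/ + snapshots/<commit>/)
    # and returns the snapshot directory.
    snapshot_dir = await download_repo_files(
        "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
        progress_callback=print_progress,
        allow_patterns=["*.json", "tokenizer.model", "*.safetensors"],
    )
    print("Snapshot at:", snapshot_dir)

asyncio.run(main())
```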

+ 70 - 0
exo/download/hf/hf_shard_download.py

@@ -0,0 +1,70 @@
+import asyncio
+import traceback
+from pathlib import Path
+from typing import Dict, List, Tuple
+from exo.inference.shard import Shard
+from exo.download.shard_download import ShardDownloader
+from exo.download.download_progress import RepoProgressEvent
+from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns
+from exo.helpers import AsyncCallbackSystem, DEBUG
+
+class HFShardDownloader(ShardDownloader):
+    def __init__(self):
+        self.active_downloads: Dict[Shard, asyncio.Task] = {}
+        self.completed_downloads: Dict[Shard, Path] = {}
+        self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+
+    async def ensure_shard(self, shard: Shard) -> Path:
+        if shard in self.completed_downloads:
+            return self.completed_downloads[shard]
+
+        # If a download on this shard is already in progress, keep that one
+        for active_shard in self.active_downloads:
+            if active_shard == shard:
+                if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
+                return await self.active_downloads[shard]
+
+        # Cancel any downloads for this model_id on a different shard
+        existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
+        for active_shard in existing_active_shards:
+            if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
+            task = self.active_downloads[active_shard]
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass  # This is expected when cancelling a task
+            except Exception as e:
+                if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
+                traceback.print_exc()
+        self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
+
+        # Start new download
+        download_task = asyncio.create_task(self._download_shard(shard))
+        self.active_downloads[shard] = download_task
+        try:
+            path = await download_task
+            self.completed_downloads[shard] = path
+            return path
+        finally:
+            # Ensure the task is removed even if an exception occurs
+            print(f"Removing download task for {shard}: {shard in self.active_downloads}")
+            if shard in self.active_downloads:
+                self.active_downloads.pop(shard)
+
+    async def _download_shard(self, shard: Shard) -> Path:
+        async def wrapped_progress_callback(event: RepoProgressEvent):
+            self._on_progress.trigger_all(shard, event)
+
+        weight_map = await get_weight_map(shard.model_id)
+        allow_patterns = get_allow_patterns(weight_map, shard)
+
+        return await download_repo_files(
+            repo_id=shard.model_id,
+            progress_callback=wrapped_progress_callback,
+            allow_patterns=allow_patterns
+        )
+
+    @property
+    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+        return self._on_progress
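A minimal sketch (not part of this diff) of using `HFShardDownloader` on its own; the model id and layer range are illustrative:

```python
import asyncio
from exo.inference.shard import Shard
from exo.download.hf.hf_shard_download import HFShardDownloader

async def main():
    downloader = HFShardDownloader()
    # The shard describes which contiguous layer range this node serves;
    # ensure_shard fetches only the weight files covering those layers.
    shard = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit",
                  start_layer=0, end_layer=31, n_layers=32)
    model_path = await downloader.ensure_shard(shard)
    print("Weights available at:", model_path)

asyncio.run(main())
```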

+ 25 - 0
exo/download/shard_download.py

@@ -0,0 +1,25 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple
+from pathlib import Path
+from exo.inference.shard import Shard
+from exo.download.download_progress import RepoProgressEvent
+from exo.helpers import AsyncCallbackSystem
+
+class ShardDownloader(ABC):
+    @abstractmethod
+    async def ensure_shard(self, shard: Shard) -> Path:
+        """
+        Ensures that the shard is downloaded.
+        Does not allow multiple overlapping downloads at once.
+        If you try to download a Shard which overlaps a Shard that is already being downloaded,
+        the download will be cancelled and a new download will start.
+
+        Args:
+            shard (Shard): The shard to download.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+        pass
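To illustrate the contract, a hypothetical alternative backend (not part of this diff) that satisfies the interface by serving weights from a pre-populated local directory:

```python
from pathlib import Path
from typing import Tuple
from exo.inference.shard import Shard
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import AsyncCallbackSystem

class LocalShardDownloader(ShardDownloader):
    """Hypothetical backend: weights are assumed to already exist on disk."""

    def __init__(self, base_dir: Path):
        self.base_dir = base_dir
        self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

    async def ensure_shard(self, shard: Shard) -> Path:
        # Nothing to download; just map the model id onto a local folder.
        return self.base_dir / shard.model_id.replace("/", "--")

    @property
    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
        return self._on_progress
```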

+ 27 - 3
exo/helpers.py

@@ -31,17 +31,17 @@ def get_system_info():
   return "Non-Mac, non-Linux system"


-def get_inference_engine(inference_engine_name):
+def get_inference_engine(inference_engine_name, shard_downloader: 'ShardDownloader'):
   if inference_engine_name == "mlx":
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine

-    return MLXDynamicShardInferenceEngine()
+    return MLXDynamicShardInferenceEngine(shard_downloader)
   elif inference_engine_name == "tinygrad":
     from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
     import tinygrad.helpers
     tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))

-    return TinygradDynamicShardInferenceEngine()
+    return TinygradDynamicShardInferenceEngine(shard_downloader)
   else:
     raise ValueError(f"Inference engine {inference_engine_name} not supported")

@@ -201,3 +201,27 @@ def get_or_create_node_id():
     except Exception as e:
         if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
         return str(uuid.uuid4())
+
+def pretty_print_bytes(size_in_bytes: int) -> str:
+    if size_in_bytes < 1024:
+        return f"{size_in_bytes} B"
+    elif size_in_bytes < 1024 ** 2:
+        return f"{size_in_bytes / 1024:.2f} KB"
+    elif size_in_bytes < 1024 ** 3:
+        return f"{size_in_bytes / (1024 ** 2):.2f} MB"
+    elif size_in_bytes < 1024 ** 4:
+        return f"{size_in_bytes / (1024 ** 3):.2f} GB"
+    else:
+        return f"{size_in_bytes / (1024 ** 4):.2f} TB"
+
+def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
+    if bytes_per_second < 1024:
+        return f"{bytes_per_second} B/s"
+    elif bytes_per_second < 1024 ** 2:
+        return f"{bytes_per_second / 1024:.2f} KB/s"
+    elif bytes_per_second < 1024 ** 3:
+        return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
+    elif bytes_per_second < 1024 ** 4:
+        return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
+    else:
+        return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
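A quick illustration (not part of this diff) of the new human-readable formatting helpers:

```python
from exo.helpers import pretty_print_bytes, pretty_print_bytes_per_second

print(pretty_print_bytes(1536))                  # "1.50 KB"
print(pretty_print_bytes(3 * 1024 ** 3))         # "3.00 GB"
print(pretty_print_bytes_per_second(2_621_440))  # "2.50 MB/s"
```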

+ 1 - 6
exo/inference/inference_engine.py

@@ -1,10 +1,9 @@
 import numpy as np

-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from .shard import Shard

-
 class InferenceEngine(ABC):
   @abstractmethod
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
@@ -13,7 +12,3 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
-
-  @abstractmethod
-  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
-    pass

+ 6 - 7
exo/inference/mlx/sharded_inference_engine.py

@@ -4,13 +4,14 @@ from ..inference_engine import InferenceEngine
 from .sharded_model import StatefulShardedModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
-from typing import Optional, Callable
+from typing import Optional
+from exo.download.shard_download import ShardDownloader


 class MLXDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self, on_download_progress: Callable[[int, int], None] = None):
+  def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
-    self.on_download_progress = on_download_progress
+    self.shard_downloader = shard_downloader

   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
@@ -33,9 +34,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return

-    model_shard, self.tokenizer = await load_shard(shard.model_id, shard, on_download_progress=self.on_download_progress)
+    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_shard, self.tokenizer = await load_shard(model_path, shard)
     self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
     self.shard = shard
-
-  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
-    self.on_download_progress = on_download_progress

+ 2 - 223
exo/inference/mlx/sharded_utils.py

@@ -12,22 +12,15 @@ from typing import Optional, Tuple, Union, List, Callable
 from PIL import Image
 from io import BytesIO
 import base64
-import os
-import concurrent.futures

-from exo import DEBUG
 import mlx.core as mx
 import mlx.nn as nn
-from huggingface_hub import snapshot_download, list_repo_tree, get_paths_info
-from huggingface_hub.utils import filter_repo_objects
-from huggingface_hub.file_download import repo_folder_name
-from huggingface_hub.constants import HF_HUB_CACHE
-from huggingface_hub.utils._errors import RepositoryNotFoundError
 from transformers import AutoProcessor

 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tuner.utils import apply_lora_layers

+from exo import DEBUG
 from ..shard import Shard


@@ -163,228 +156,14 @@ def load_model_shard(
   model.eval()
   return model

-
-async def get_repo_size(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
-  it = await asyncio.to_thread(list_repo_tree, repo_id, revision=revision, repo_type=repo_type)
-  files = list(filter_repo_objects(it, allow_patterns=allow_patterns, key=lambda f: f.path))
-  return sum(file.size for file in files if hasattr(file, "size") and file.size is not None)
-
-async def monitor_progress(dir, total_size, print_progress=False, on_progress: Callable[[int, int], None] = None):
-    while True:
-      try:
-        await asyncio.sleep(0.1)
-        current_size = sum(os.path.getsize(os.path.join(root, file))
-                            for root, _, files in os.walk(dir)
-                            for file in files)
-        progress = min(current_size / total_size * 100, 100)
-        if print_progress:
-          print(f"\rProgress: {progress:.2f}% ({current_size}/{total_size} bytes)", end="", flush=True)
-        if on_progress:
-          on_progress(current_size, total_size)
-        if progress >= 100:
-          if print_progress:
-            print("\nDownload complete!")
-          break
-      except Exception as e:
-        print(f"Error monitoring progress: {e}")
-
-async def download_repo(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
-    with concurrent.futures.ThreadPoolExecutor() as pool:
-        return await asyncio.get_event_loop().run_in_executor(
-            pool,
-            partial(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
-        )
-
-async def download_async_with_progress(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, on_progress: Callable[[int, int], None] = None):
-  storage_folder = os.path.join(HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type="model"))
-  # os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
-  # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
-
-  total_size = await get_repo_size(repo_id)
-
-  # Create tasks for download and progress checking
-  download_task = asyncio.create_task(download_repo(repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type))
-  progress_task = asyncio.create_task(monitor_progress(storage_folder, total_size, on_progress=on_progress))
-
-  # Wait for both tasks to complete
-  result = await asyncio.gather(download_task, progress_task, return_exceptions=True)
-  return result[0]  # Return the result from download_task
-
-repo_id_safetensors_layers = {
-  "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit": {
-    "model.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
-  },
-  "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit": {
-    "model-00001-of-00008.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
-    "model-00002-of-00008.safetensors": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
-    "model-00003-of-00008.safetensors": [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
-    "model-00004-of-00008.safetensors": [31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42],
-    "model-00005-of-00008.safetensors": [42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53],
-    "model-00006-of-00008.safetensors": [53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64],
-    "model-00007-of-00008.safetensors": [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75],
-    "model-00008-of-00008.safetensors": [75, 76, 77, 78, 79],
-  },
-  "mlx-community/Meta-Llama-3.1-405B-Instruct-4bit": {
-    "model-00001-of-00046.safetensors": [0, 1, 2],
-    "model-00002-of-00046.safetensors": [2, 3, 4, 5],
-    "model-00003-of-00046.safetensors": [5, 6, 7],
-    "model-00004-of-00046.safetensors": [8, 9, 10],
-    "model-00005-of-00046.safetensors": [10, 11, 12, 13],
-    "model-00006-of-00046.safetensors": [13, 14, 15, 16],
-    "model-00007-of-00046.safetensors": [16, 17, 18, 19],
-    "model-00008-of-00046.safetensors": [19, 20, 21],
-    "model-00009-of-00046.safetensors": [22, 23, 24],
-    "model-00010-of-00046.safetensors": [24, 25, 26, 27],
-    "model-00011-of-00046.safetensors": [27, 28, 29, 30],
-    "model-00012-of-00046.safetensors": [30, 31, 32, 33],
-    "model-00013-of-00046.safetensors": [33, 34, 35],
-    "model-00014-of-00046.safetensors": [36, 37, 38],
-    "model-00015-of-00046.safetensors": [38, 39, 40, 41],
-    "model-00016-of-00046.safetensors": [41, 42, 43, 44],
-    "model-00017-of-00046.safetensors": [44, 45, 46, 47],
-    "model-00018-of-00046.safetensors": [47, 48, 49],
-    "model-00019-of-00046.safetensors": [50, 51, 52],
-    "model-00020-of-00046.safetensors": [52, 53, 54, 55],
-    "model-00021-of-00046.safetensors": [55, 56, 57, 58],
-    "model-00022-of-00046.safetensors": [58, 59, 60, 61],
-    "model-00023-of-00046.safetensors": [61, 62, 63],
-    "model-00024-of-00046.safetensors": [64, 65, 66],
-    "model-00025-of-00046.safetensors": [66, 67, 68, 69],
-    "model-00026-of-00046.safetensors": [69, 70, 71, 72],
-    "model-00027-of-00046.safetensors": [72, 73, 74, 75],
-    "model-00028-of-00046.safetensors": [75, 76, 77],
-    "model-00029-of-00046.safetensors": [78, 79, 80],
-    "model-00030-of-00046.safetensors": [80, 81, 82, 83],
-    "model-00031-of-00046.safetensors": [83, 84, 85, 86],
-    "model-00032-of-00046.safetensors": [86, 87, 88, 89],
-    "model-00033-of-00046.safetensors": [89, 90, 91],
-    "model-00034-of-00046.safetensors": [92, 93, 94],
-    "model-00035-of-00046.safetensors": [94, 95, 96, 97],
-    "model-00036-of-00046.safetensors": [97, 98, 99, 100],
-    "model-00037-of-00046.safetensors": [100, 101, 102, 103],
-    "model-00038-of-00046.safetensors": [103, 104, 105],
-    "model-00039-of-00046.safetensors": [106, 107, 108],
-    "model-00040-of-00046.safetensors": [108, 109, 110, 111],
-    "model-00041-of-00046.safetensors": [111, 112, 113, 114],
-    "model-00042-of-00046.safetensors": [114, 115, 116, 117],
-    "model-00043-of-00046.safetensors": [117, 118, 119],
-    "model-00044-of-00046.safetensors": [120, 121, 122],
-    "model-00045-of-00046.safetensors": [122, 123, 124, 125],
-    "model-00046-of-00046.safetensors": [125]
-  },
-  "mlx-community/Mistral-Nemo-Instruct-2407-4bit": {
-    "model-00001-of-00002.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32],
-    "model-00002-of-00002.safetensors": [32, 33, 34, 35, 36, 37, 38, 39],
-  },
-  "mlx-community/Mistral-Large-Instruct-2407-4bit": {
-    "model-00001-of-00014.safetensors": [0, 1, 2, 3, 4, 5, 6],
-    "model-00002-of-00014.safetensors": [6, 7, 8, 9, 10, 11, 12, 13],
-    "model-00003-of-00014.safetensors": [13, 14, 15, 16, 17, 18, 19, 20],
-    "model-00004-of-00014.safetensors": [20, 21, 22, 23, 24, 25, 26],
-    "model-00005-of-00014.safetensors": [27, 28, 29, 30, 31, 32, 33],
-    "model-00006-of-00014.safetensors": [33, 34, 35, 36, 37, 38, 39, 40],
-    "model-00007-of-00014.safetensors": [40, 41, 42, 43, 44, 45, 46, 47],
-    "model-00008-of-00014.safetensors": [47, 48, 49, 50, 51, 52, 53, 54],
-    "model-00009-of-00014.safetensors": [54, 55, 56, 57, 58, 59, 60],
-    "model-00010-of-00014.safetensors": [61, 62, 63, 64, 65, 66, 67],
-    "model-00011-of-00014.safetensors": [67, 68, 69, 70, 71, 72, 73, 74],
-    "model-00012-of-00014.safetensors": [74, 75, 76, 77, 78, 79, 80, 81],
-    "model-00013-of-00014.safetensors": [81, 82, 83, 84, 85, 86, 87],
-    "model-00014-of-00014.safetensors": [87]
-  },
-  "llava-hf/llava-1.5-7b-hf": {
-    "model-00001-of-00003.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
-    "model-00002-of-00003.safetensors": [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22],
-    "model-00003-of-00003.safetensors": [22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
-  }
-}
-
-def get_safetensors_allow_patterns(repo_id: str, shard: Optional[Shard] = None):
-    return ["*.safetensors"] # TODO: enable this
-    if not shard:
-      return ["*.safetensors"]
-
-    allow_patterns = []
-    for repo_id, safetensors_layers in repo_id_safetensors_layers.items():
-        if repo_id == shard.model_id:
-            for safetensor, layers in safetensors_layers.items():
-                if any(shard.start_layer <= layer <= shard.end_layer for layer in layers):
-                    allow_patterns.append(safetensor)
-
-    return allow_patterns if len(allow_patterns) > 0 else ["*.safetensors"]
-
-async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
-  """
-  Ensures the model is available locally. If the path does not exist locally,
-  it is downloaded from the Hugging Face Hub.
-
-  Args:
-   path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
-   revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
-
-  Returns:
-   Path: The path to the model.
-  """
-  model_path = Path(path_or_hf_repo)
-  if not model_path.exists():
-    try:
-      model_path = Path(
-        await download_async_with_progress(
-          repo_id=path_or_hf_repo,
-          revision=revision,
-          allow_patterns=[
-            "*.json",
-            "*.py",
-            "tokenizer.model",
-            "*.tiktoken",
-            "*.txt",
-          ] + get_safetensors_allow_patterns(path_or_hf_repo, shard),
-          on_progress=on_download_progress,
-        )
-      )
-    except RepositoryNotFoundError:
-      raise ModelNotFoundError(
-        f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
-        "Please make sure you specified the local path or Hugging Face"
-        " repo id correctly.\nIf you are trying to access a private or"
-        " gated Hugging Face repo, make sure you are authenticated:\n"
-        "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
-      ) from None
-  return model_path
-
-
 async def load_shard(
-  path_or_hf_repo: str,
+  model_path: str,
   shard: Shard,
   tokenizer_config={},
   model_config={},
   adapter_path: Optional[str] = None,
   lazy: bool = False,
-  on_download_progress: Callable[[int, int], None] = None,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
-  """
-  Load the model and tokenizer from a given path or a huggingface repository.
-
-  Args:
-   path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
-   tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
-    Defaults to an empty dictionary.
-   model_config(dict, optional): Configuration parameters specifically for the model.
-    Defaults to an empty dictionary.
-   adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
-    to the model. Default: ``None``.
-   lazy (bool): If False eval the model parameters to make sure they are
-    loaded in memory before returning, otherwise they will be loaded
-    when needed. Default: ``False``
-  Returns:
-   Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
-
-  Raises:
-   FileNotFoundError: If config file or safetensors are not found.
-   ValueError: If model class or args class are not found.
-  """
-  model_path = await get_model_path(path_or_hf_repo, shard, on_download_progress=on_download_progress)
-
   model = load_model_shard(model_path, shard, lazy, model_config)
   if adapter_path is not None:
     model = apply_lora_layers(model, adapter_path)

+ 17 - 2
exo/inference/shard.py

@@ -1,13 +1,16 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field


-@dataclass
+@dataclass(frozen=True)
 class Shard:
   model_id: str
   start_layer: int
   end_layer: int
   n_layers: int

+  def __hash__(self):
+    return hash((self.model_id, self.start_layer, self.end_layer, self.n_layers))
+
   def is_first_layer(self) -> bool:
     return self.start_layer == 0

@@ -24,3 +27,15 @@ class Shard:
       "end_layer": self.end_layer,
       "n_layers": self.n_layers,
     }
+
+  def from_dict(data: dict) -> 'Shard':
+    return Shard(**data)
+
+  def overlaps(self, other: 'Shard') -> bool:
+    return shards_overlap(self, other)
+
+def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
+  return (
+      shard1.model_id == shard2.model_id
+      and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer)
+  )
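A small sketch (not part of this diff) of the new hash/overlap semantics, which is what lets `HFShardDownloader` key its `active_downloads` dict by `Shard`; the model id here is illustrative:

```python
from exo.inference.shard import Shard

a = Shard(model_id="example/model", start_layer=0, end_layer=15, n_layers=32)
b = Shard(model_id="example/model", start_layer=12, end_layer=31, n_layers=32)
c = Shard(model_id="example/model", start_layer=16, end_layer=31, n_layers=32)

assert a.overlaps(b)      # layer ranges [0, 15] and [12, 31] intersect
assert not a.overlaps(c)  # [0, 15] and [16, 31] are disjoint

# frozen=True plus __hash__ means equal shards hash equally, so they can key dicts.
downloads = {a: "task-a"}
assert downloads[Shard(model_id="example/model", start_layer=0, end_layer=15, n_layers=32)] == "task-a"
```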

+ 8 - 5
exo/inference/test_inference_engine.py

@@ -1,4 +1,6 @@
+from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 import asyncio
@@ -9,6 +11,7 @@ import numpy as np
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+
   next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
@@ -42,15 +45,15 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
 
 
 asyncio.run(
 asyncio.run(
   test_inference_engine(
   test_inference_engine(
-    MLXDynamicShardInferenceEngine(),
-    MLXDynamicShardInferenceEngine(),
+    MLXDynamicShardInferenceEngine(HFShardDownloader()),
+    MLXDynamicShardInferenceEngine(HFShardDownloader()),
     "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
     "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
   )
   )
 )
 )
 
 
 # TODO: Need more memory or a smaller model
 # TODO: Need more memory or a smaller model
 # asyncio.run(test_inference_engine(
 # asyncio.run(test_inference_engine(
-#     TinygradDynamicShardInferenceEngine(),
-#     TinygradDynamicShardInferenceEngine(),
-#     "llama3-8b-sfr",
+#     TinygradDynamicShardInferenceEngine(HFShardDownloader()),
+#     TinygradDynamicShardInferenceEngine(HFShardDownloader()),
+#     "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
 # ))
 # ))
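
The updated test injects an HFShardDownloader into each engine and splits the model across two engines; the middle of the test is elided in this hunk, but the pattern it exercises is roughly the following sketch (the shard split, request id and prompt are chosen for illustration):

import asyncio
from exo.inference.shard import Shard
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.download.hf.hf_shard_download import HFShardDownloader

async def two_engine_step(model_id: str, prompt: str):
  engine_1 = MLXDynamicShardInferenceEngine(HFShardDownloader())
  engine_2 = MLXDynamicShardInferenceEngine(HFShardDownloader())
  shard_1 = Shard(model_id=model_id, start_layer=0, end_layer=15, n_layers=32)
  shard_2 = Shard(model_id=model_id, start_layer=16, end_layer=31, n_layers=32)

  # first half of the layers: prompt in, hidden states out
  hidden, state_1, _ = await engine_1.infer_prompt("B", shard=shard_1, prompt=prompt)
  # second half: hidden states in, sampled token out
  token, _, is_finished = await engine_2.infer_tensor("B", shard=shard_2, input_data=hidden, inference_state=state_1)
  return token, is_finished

# asyncio.run(two_engine_step("mlx-community/Meta-Llama-3-8B-Instruct-4bit", "In a single word only, what is the capital of France?"))
# (commented out: running it requires the model weights to be downloadable locally)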

+ 52 - 221
exo/inference/tinygrad/inference.py

@@ -1,266 +1,97 @@
-import asyncio
-from functools import partial
 from pathlib import Path
 from pathlib import Path
-from typing import List, Optional, Union, Callable
+from typing import List
 import json
 import json
-import tiktoken
-from tiktoken.load import load_tiktoken_bpe
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
-from tinygrad.nn.state import safe_load, torch_load, load_state_dict
-from tinygrad import Tensor, nn, Context, GlobalCounters
-from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict
+from tinygrad import Tensor, dtypes, nn, Context
+from transformers import AutoTokenizer
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.inference_engine import InferenceEngine
+from typing import Optional, Tuple
 import numpy as np
 import numpy as np
-import os
+from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
+from exo.download.shard_download import ShardDownloader
 
 
+Tensor.no_grad = True
+# default settings
+TEMPERATURE = 0.85
+TOP_K = 25
+TOP_P = 0.9
+ALPHA_F = 0.1
+ALPHA_P = 0.0
 MODEL_PARAMS = {
 MODEL_PARAMS = {
   "8B": {
   "8B": {
-    "args": {
-      "dim": 4096,
-      "n_heads": 32,
-      "n_kv_heads": 8,
-      "n_layers": 32,
-      "norm_eps": 1e-5,
-      "rope_theta": 500000,
-      "vocab_size": 128256,
-      "hidden_dim": 14336,
-    },
-    "files": 1,
+    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
+    "files": 1
   },
   },
   "70B": {
   "70B": {
-    "args": {
-      "dim": 8192,
-      "n_heads": 64,
-      "n_kv_heads": 8,
-      "n_layers": 80,
-      "norm_eps": 1e-5,
-      "rope_theta": 500000,
-      "vocab_size": 128256,
-      "hidden_dim": 28672,
-    },
-    "files": 8,
-  },
+    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 28672},
+    "files": 8
+  }
 }
 }
 
 
-
-
-# **** helper functions ****
-async def fetch_async(
-  url: str,
-  name: Optional[Union[Path, str]] = None,
-  subdir: Optional[str] = None,
-  allow_caching=not os.getenv("DISABLE_HTTP_CACHE"),
-) -> Path:
-  func = partial(fetch, url, name, subdir, allow_caching)
-  return await asyncio.get_event_loop().run_in_executor(None, func)
-
-
-def concat_weights(models, device=None):
-  def convert(name) -> Tensor:
-    disk_tensors: List[Tensor] = [model[name] for model in models]
-    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
-      return disk_tensors[0].to(device=device)
-    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
-    lazy_tensors = [data.to(device=device) for data in disk_tensors]
-    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
-
-  return {name: convert(name) for name in {name: None for model in models for name in model}}
-
-
-def load(fn: str):
-  if fn.endswith(".index.json"):
-    with open(fn) as fp:
-      weight_map = json.load(fp)["weight_map"]
-    parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
-    return {k: parts[n][k] for k, n in weight_map.items()}
-  elif fn.endswith(".safetensors"):
-    return safe_load(fn)
-  else:
-    return torch_load(fn)
-
-
-def build_transformer(model_path: Path, shard: Shard, model_size="8B", quantize=None, device=None):
+def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   # build model
   # build model
   linear = nn.Linear
   linear = nn.Linear
   with Context(THREEFRY=0):
   with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], shard=shard, linear=linear, max_context=8192, jit=False)
+    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
 
 
   # load weights
   # load weights
   if model_path.is_dir():
   if model_path.is_dir():
-    if (model_path / "model.safetensors.index.json").exists():
-      weights = load(str(model_path / "model.safetensors.index.json"))
-    elif (model_path / "model.safetensors").exists():
-      weights = load(str(model_path / "model.safetensors"))
-    else:
-      weights = concat_weights(
-        [load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])],
-        device[0] if isinstance(device, tuple) else device,
-      )
+    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"), shard)
+    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"), shard)
+    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
   else:
   else:
-    weights = load(str(model_path))
+    weights = load(str(model_path), shard)
   if "model.embed_tokens.weight" in weights:
   if "model.embed_tokens.weight" in weights:
-    weights = convert_from_huggingface(
-      weights,
-      model,
-      MODEL_PARAMS[model_size]["args"]["n_heads"],
-      MODEL_PARAMS[model_size]["args"]["n_kv_heads"],
-      shard=shard,
-    )
+    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
   weights = fix_bf16(weights)
   weights = fix_bf16(weights)
 
 
   with Context(BEAM=0):
   with Context(BEAM=0):
-    # quantize
-    if quantize is not None:
-      weights = linear.quantize(weights, device)
-      for _, v in weights.items():
-        v.realize()
-
-    # shard
-    if isinstance(device, tuple):
-      for k, v in nn.state.get_state_dict(model).items():
-        if "scale" in k:
-          v.shard_(device, axis=None)  # from quantized
-        elif ".attention." in k:
-          v.shard_(device, axis=-1)
-        elif ".feed_forward.w1." in k:
-          v.shard_(device, axis=0)
-        elif ".feed_forward.w3." in k:
-          v.shard_(device, axis=0)
-        elif ".feed_forward." in k:
-          v.shard_(device, axis=-1)
-        elif "tok_embeddings.weight" in k:
-          v.shard_(device, axis=0)
-        elif "output.weight" in k:
-          v.shard_(device, axis=0)
-        else:
-          v.shard_(device, axis=None)
-
     # replace weights in model
     # replace weights in model
-    load_state_dict(model, weights, strict=False, consume=True)
+    load_state_dict(model, weights, strict=False, consume=False) # consume=True
   return model
   return model
 
 
-
-# default settings
-TEMPERATURE = 0  # 0.85
-TOP_K = 25
-TOP_P = 0.9
-ALPHA_F = 0.1
-ALPHA_P = 0.0
-
-
-def prefill(model, toks, start_pos=0):
-  # prefill the model
-  for tok in tqdm(toks):
-    GlobalCounters.reset()
-    model(Tensor([[tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).realize()
-    start_pos += 1
-  return start_pos
-
-
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self):
+  def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard = None
+    self.shard_downloader = shard_downloader
 
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-    # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
+    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
+    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
 
     toks = self.tokenizer.encode(prompt)
     toks = self.tokenizer.encode(prompt)
-    start_pos = prefill(self.model, toks[:-1], start_pos=start_pos)
-    last_tok = toks[-1]
-
-    output_data = np.array([self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
-    if output_data.size == 1:
-      start_pos += 1
+    h = self.model(Tensor([toks]), start_pos, TEMPERATURE)
 
 
-    return (
-      output_data,
-      json.dumps({"start_pos": start_pos}),
-      output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id],
-    )
+    if h.shape == (1,):
+      start_pos += len(toks)
+      n_captured_toks = 1
+      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
+    else:
+      n_captured_toks += len(toks)
+      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks + 1}), False
 
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
+    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
+    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
 
-    output_data: np.ndarray = np.array([self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
-    if output_data.size == 1:
-      start_pos += 1
+    h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
 
 
-    return (
-      output_data,
-      json.dumps({"start_pos": start_pos}),
-      output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id],
-    )
+    if h.shape == (1,):
+      start_pos += n_captured_toks
+      n_captured_toks = 1
+      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
+    else:
+      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
 
 
   async def ensure_shard(self, shard: Shard):
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
     if self.shard == shard:
       return
       return
 
 
-    model_path = Path(shard.model_id)
-    models_dir = Path(_cache_dir) / "tinygrad" / "downloads"
-    model_path = models_dir / shard.model_id
-    size = "8B"
-    if Path(model_path / "tokenizer_config.json").exists():
-      model = model_path
-    else:
-
-      if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
-      if shard.model_id.lower().find("llama3-8b-sfr") != -1:
-        num_files = 4
-        for i in range(num_files):
-          await fetch_async(
-            f"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/model-{(i+1):05d}-of-{num_files:05d}.safetensors",
-            f"model-{(i+1):05d}-of-{num_files:05d}.safetensors",
-            subdir=shard.model_id,
-          )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/config.json",
-          "config.json",
-          subdir=shard.model_id,
-        )
-        model = await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/raw/main/model.safetensors.index.json",
-          "model.safetensors.index.json",
-          subdir=shard.model_id,
-        )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/special_tokens_map.json",
-          "special_tokens_map.json",
-          subdir=shard.model_id,
-        )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json",
-          "tokenizer.json",
-          subdir=shard.model_id,
-        )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer_config.json",
-          "tokenizer_config.json",
-          subdir=shard.model_id,
-        )
-        size = "8B"
-      elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
-        raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
-        # fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-70B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
-        # model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
-        # size = "70B"
-      else:
-        raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}")
-
-    model = build_transformer(model_path, shard=shard, model_size=size)
-    from transformers import AutoTokenizer
-    tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
-
+    model_path = await self.shard_downloader.ensure_shard(shard)
+    self.model = build_transformer(model_path, shard, model_size="8B")
+    self.tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
     self.shard = shard
     self.shard = shard
-    self.model = model
-    self.tokenizer = tokenizer
-
-  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
-    pass
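
The engine now resolves weights through the injected ShardDownloader and threads its decode position through the opaque inference_state JSON instead of module-level state. A sketch of single-node generation with a shard covering all 32 layers (model id taken from the commented-out test above; everything else is illustrative):

import asyncio
from typing import List
from exo.inference.shard import Shard
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
from exo.download.hf.hf_shard_download import HFShardDownloader

async def generate(prompt: str, max_new_tokens: int = 32) -> List[int]:
  engine = TinygradDynamicShardInferenceEngine(HFShardDownloader())
  shard = Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=31, n_layers=32)

  out, state, done = await engine.infer_prompt("req-1", shard=shard, prompt=prompt)
  tokens = [int(out[0][0])]
  while not done and len(tokens) < max_new_tokens:
    # state is a JSON string carrying {"start_pos": ..., "n_captured_toks": ...}
    out, state, done = await engine.infer_tensor("req-1", shard=shard, input_data=out, inference_state=state)
    tokens.append(int(out[0][0]))
  return tokens

# asyncio.run(generate("Hello"))  # needs the 8B weights available via the downloader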

+ 53 - 125
exo/inference/tinygrad/models/llama.py

@@ -1,26 +1,22 @@
 from typing import Tuple, Union, Optional, Dict, Any
 from typing import Tuple, Union, Optional, Dict, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 from tinygrad.helpers import getenv
-from exo.inference.shard import Shard
-
 
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
-  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[: (dim // 2)] / dim))
+  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
   # TODO: move dtype outside this
-  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
-
+  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim//2, 2)
 
 
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):
 def complex_mult(A, c, d):
-  a, b = A[..., 0:1], A[..., 1:2]
-  ro = a * c - b * d
-  co = a * d + b * c
+  a,b = A[..., 0:1], A[..., 1:2]
+  ro = a*c - b*d
+  co = a*d + b*c
   return ro.cat(co, dim=-1)
   return ro.cat(co, dim=-1)
 
 
-
-def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
+def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
   assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
@@ -30,19 +26,16 @@ def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor,
   xk_out = complex_mult(xk, c, d)
   xk_out = complex_mult(xk, c, d)
   return xq_out.flatten(3), xk_out.flatten(3)
   return xq_out.flatten(3), xk_out.flatten(3)
 
 
-
-def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
+def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
   bs, seqlen, n_kv_heads, head_dim = x.shape
   bs, seqlen, n_kv_heads, head_dim = x.shape
-  if n_rep == 1:
-    return x
+  if n_rep == 1: return x
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
 
 
-
 class Attention:
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
     self.n_heads = n_heads
-    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
+    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
     self.head_dim = dim // n_heads
     self.head_dim = dim // n_heads
     self.n_rep = self.n_heads // self.n_kv_heads
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context
     self.max_context = max_context
@@ -52,8 +45,14 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
-    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
+    if getenv("WQKV"):
+      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
+      xqkv = x @ self.wqkv.T
+      xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
+    else:
+      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
     xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
     xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
     xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
     xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
@@ -66,14 +65,14 @@ class Attention:
       self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
       self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
       if isinstance(x.device, tuple):
       if isinstance(x.device, tuple):
         # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
         # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-        self.cache_kv.shard_((x.device), axis=None).realize()
+        self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
 
 
     # update the cache
     # update the cache
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
+    self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
 
-    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
 
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -81,39 +80,26 @@ class Attention:
     attn = attn.reshape(bsz, seqlen, -1)
     attn = attn.reshape(bsz, seqlen, -1)
     return self.wo(attn)
     return self.wo(attn)
 
 
-
 class FeedForward:
 class FeedForward:
-  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
+  def __init__(self, dim:int, hidden_dim:int, linear=nn.Linear):
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
-    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit
-
-  def __call__(self, x: Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu() * self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
+    self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
 
 
+  def __call__(self, x:Tensor) -> Tensor:
+    return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
 
 
 class TransformerBlock:
 class TransformerBlock:
-  def __init__(
-    self,
-    dim: int,
-    hidden_dim: int,
-    n_heads: int,
-    n_kv_heads: int,
-    norm_eps: float,
-    max_context: int,
-    linear=nn.Linear,
-    feed_forward=FeedForward,
-  ):
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
 
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
+  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
 
-
 # standard openai sampling
 # standard openai sampling
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
   assert logits.ndim == 1, "only works on 1d tensors"
@@ -121,8 +107,7 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
 
 
   # if temperature is very low just use argmax
   # if temperature is very low just use argmax
-  if temp < 1e-6:
-    return logits.argmax()
+  if temp < 1e-6: return logits.argmax()
 
 
   # alpha sampling
   # alpha sampling
   if af or ap:
   if af or ap:
@@ -136,16 +121,10 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   # softmax
   # softmax
   t = (logits / temp).softmax()
   t = (logits / temp).softmax()
 
 
-  counter, counter2 = (
-    Tensor.arange(t.numel(), device=logits.device).contiguous(),
-    Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous(),
-  )
+  counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
   # top k
   # top k
   if k:
   if k:
-    output, output_indices = (
-      Tensor.zeros(k, device=logits.device).contiguous(),
-      Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous(),
-    )
+    output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
     for i in range(k):
     for i in range(k):
       t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
       t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
       output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
       output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
@@ -170,84 +149,48 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 
 
   return output_token
   return output_token
 
 
+from exo.inference.shard import Shard
 
 
 class Transformer:
 class Transformer:
-  def __init__(
-    self,
-    dim: int,
-    hidden_dim: int,
-    n_heads: int,
-    n_layers: int,
-    norm_eps: float,
-    vocab_size,
-    shard: Shard,
-    linear=nn.Linear,
-    n_kv_heads=None,
-    rope_theta=10000,
-    max_context=1024,
-    jit=True,
-    feed_forward=FeedForward,
-  ):
-    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(shard.end_layer - shard.start_layer + 1)]
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, shard: Shard=None, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
+    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
     self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous()
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
     self.shard = shard
 
 
-  def forward(
-    self,
-    h: Tensor,
-    start_pos: Union[Variable, int],
-    temperature: float,
-    top_k: int,
-    top_p: float,
-    alpha_f: float,
-    alpha_p: float,
-  ):
-    seqlen = h.shape[1]
-    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
+  def forward(self, x:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
+    seqlen = x.shape[1]
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
+    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos+1).realize() if seqlen > 1 else None
 
 
     if self.shard.is_first_layer():
     if self.shard.is_first_layer():
-      h = self.tok_embeddings(h)
-    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos + 1).realize() if seqlen > 1 else None
+      h = self.tok_embeddings(x)
+    else:
+      h = x
 
 
-    for layer in self.layers:
+    for i in range(self.shard.start_layer, self.shard.end_layer + 1):
+      layer = self.layers[i]
       h = layer(h, start_pos, freqs_cis, mask)
       h = layer(h, start_pos, freqs_cis, mask)
 
 
     if self.shard.is_last_layer():
     if self.shard.is_last_layer():
       logits = self.output(self.norm(h)).float()[:, -1, :]
       logits = self.output(self.norm(h)).float()[:, -1, :]
       return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
       return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
     else:
     else:
-      return h.realize()
-
-  def __call__(
-    self,
-    tokens: Tensor,
-    start_pos: Variable,
-    temperature: float = 0.0,
-    top_k: int = 0,
-    top_p: float = 0.8,
-    alpha_f: float = 0.0,
-    alpha_p: float = 0.0,
-  ):
+      return h
+
+  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
     # TODO: better way to handle the first call v.s. the rest?
-    # if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
-    #   return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
+    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
+      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
     return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
     return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
 
 
-  def reset(self):
-    for layer in self.layers:
-      if hasattr(layer.attention, "cache_kv"):
-        layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
-
-
 # *** helpers ***
 # *** helpers ***
 
 
-
-def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard):
+def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
   def permute(v: Tensor, n_heads: int):
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 
 
@@ -255,30 +198,16 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
     "model.embed_tokens.weight": "tok_embeddings.weight",
     "model.embed_tokens.weight": "tok_embeddings.weight",
     **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
     **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
     **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
     **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.biases": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.scales": f"layers.{l}.attention.w{x}.scale" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
     **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
     **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.post_attention_layernorm.biases": f"layers.{l}.ffn_norm.bias" for l in range(len(model.layers))},
     **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
     **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.biases": f"layers.{l}.feed_forward.w{y}.bias" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.scales": f"layers.{l}.feed_forward.w{y}.scale" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
     "model.norm.weight": "norm.weight",
     "model.norm.weight": "norm.weight",
     "lm_head.weight": "output.weight",
     "lm_head.weight": "output.weight",
-    "lm_head.biases": "output.bias",
-    "lm_head.scales": "output.scale",
   }
   }
   sd = {}
   sd = {}
   for k, v in weights.items():
   for k, v in weights.items():
-    if ".rotary_emb." in k:
-      continue
+    if ".rotary_emb." in k: continue
     v = v.to(Device.DEFAULT)
     v = v.to(Device.DEFAULT)
     if "model.layers" in k:
     if "model.layers" in k:
-      layer_num = int(k.split(".")[2])
-      if shard.start_layer <= layer_num <= shard.end_layer:
-        k = f"model.layers.{layer_num - shard.start_layer}." + ".".join(k.split(".")[3:])
-      else:
-        continue
-
       if "q_proj" in k:
       if "q_proj" in k:
         v = permute(v, n_heads)
         v = permute(v, n_heads)
       elif "k_proj" in k:
       elif "k_proj" in k:
@@ -286,10 +215,9 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
     sd[keymap[k]] = v
     sd[keymap[k]] = v
   return sd
   return sd
 
 
-
-def fix_bf16(weights: Dict[Any, Tensor]):
+def fix_bf16(weights:Dict[Any, Tensor]):
   if getenv("SUPPORT_BF16", 1):
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
     # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
+    return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
   # TODO: check if device supports bf16
   # TODO: check if device supports bf16
-  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
+  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
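
Besides the shard-aware forward pass, this rewrite adds two opt-in knobs read through tinygrad's getenv: WQKV fuses the q/k/v projections into one matmul, and SHARD_KVCACHE shards the KV cache across devices instead of replicating it. A minimal sketch of turning them on (they are plain environment variables, so exporting them before launching the process works just as well):

import os

# set before the process builds the Transformer; both default to off
os.environ["WQKV"] = "1"            # concatenate wq/wk/wv into a single weight and split the matmul result
os.environ["SHARD_KVCACHE"] = "1"   # shard cache_kv along axis 3 when the model runs on a tuple of devices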

+ 38 - 0
exo/inference/tinygrad/tinygrad_helpers.py

@@ -0,0 +1,38 @@
+from tinygrad.nn.state import safe_load, torch_load
+from tinygrad import Tensor
+from pathlib import Path
+import json
+from typing import List
+from exo.inference.shard import Shard
+from exo.helpers import DEBUG
+
+# **** helper functions ****
+def concat_weights(models, device=None):
+  def convert(name) -> Tensor:
+    disk_tensors: List[Tensor] = [model[name] for model in models]
+    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
+      return disk_tensors[0].to(device=device)
+    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
+    lazy_tensors = [data.to(device=device) for data in disk_tensors]
+    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
+  return {name: convert(name) for name in {name: None for model in models for name in model}}
+
+def load(fn:str, shard: Shard):
+  if fn.endswith('.index.json'):
+    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
+    parts = {}
+    filtered_weight_map = {}
+    for k, n in weight_map.items():
+      if k.startswith("model.layers."):
+        layer_num = int(k.split('.')[2])
+        if layer_num < shard.start_layer or layer_num > shard.end_layer:
+          continue
+
+      parts[n] = load(str(Path(fn).parent / Path(n).name), shard)
+      filtered_weight_map[k] = n
+    if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {set(weight_map.keys()) - set(filtered_weight_map.keys())}")
+    return {k: parts[n][k] for k, n in filtered_weight_map.items()}
+  elif fn.endswith(".safetensors"):
+    return safe_load(fn)
+  else:
+    return torch_load(fn)
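
The new load() walks the safetensors index and keeps only the entries whose layer index falls inside the shard, so a node never loads weights it will not run. A toy check of the filtering rule it applies to weight_map keys (key names follow the HF llama layout used in convert_from_huggingface):

from exo.inference.shard import Shard

shard = Shard(model_id="m", start_layer=8, end_layer=15, n_layers=32)

def wanted(key: str) -> bool:
  # mirrors the check in load(): non-layer keys are always kept,
  # layer keys only when their index falls inside the shard's inclusive range
  if not key.startswith("model.layers."):
    return True
  layer_num = int(key.split(".")[2])
  return shard.start_layer <= layer_num <= shard.end_layer

assert wanted("model.embed_tokens.weight")
assert wanted("model.layers.8.self_attn.q_proj.weight")
assert not wanted("model.layers.31.mlp.up_proj.weight")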

+ 6 - 6
exo/networking/grpc/grpc_server.py

@@ -48,7 +48,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     image_str = request.image_str
     image_str = request.image_str
     request_id = request.request_id
     request_id = request.request_id
     result = await self.node.process_prompt(shard, prompt, image_str, request_id)
     result = await self.node.process_prompt(shard, prompt, image_str, request_id)
-    if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
+    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
 
@@ -64,14 +64,14 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     inference_state = request.inference_state
     inference_state = request.inference_state
 
 
     result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
     result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
-    if DEBUG >= 2: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
+    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
 
   async def GetInferenceResult(self, request, context):
   async def GetInferenceResult(self, request, context):
     request_id = request.request_id
     request_id = request.request_id
     result = await self.node.get_inference_result(request_id)
     result = await self.node.get_inference_result(request_id)
-    if DEBUG >= 2: print(f"GetInferenceResult {request_id=}: {result}")
+    if DEBUG >= 5: print(f"GetInferenceResult {request_id=}: {result}")
     tensor_data = result[0].tobytes() if result[0] is not None else None
     tensor_data = result[0].tobytes() if result[0] is not None else None
     return (
     return (
       node_service_pb2.InferenceResult(
       node_service_pb2.InferenceResult(
@@ -96,20 +96,20 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       for node_id, cap in topology.nodes.items()
       for node_id, cap in topology.nodes.items()
     }
     }
     peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
     peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
-    if DEBUG >= 2: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
+    if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
     return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
     return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
 
 
   async def SendResult(self, request, context):
   async def SendResult(self, request, context):
     request_id = request.request_id
     request_id = request.request_id
     result = request.result
     result = request.result
     is_finished = request.is_finished
     is_finished = request.is_finished
-    if DEBUG >= 2: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
+    if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
     self.node.on_token.trigger_all(request_id, result, is_finished)
     self.node.on_token.trigger_all(request_id, result, is_finished)
     return node_service_pb2.Empty()
     return node_service_pb2.Empty()
 
 
   async def SendOpaqueStatus(self, request, context):
   async def SendOpaqueStatus(self, request, context):
     request_id = request.request_id
     request_id = request.request_id
     status = request.status
     status = request.status
-    if DEBUG >= 2: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
+    if DEBUG >= 5: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
     self.node.on_opaque_status.trigger_all(request_id, status)
     self.node.on_opaque_status.trigger_all(request_id, status)
     return node_service_pb2.Empty()
     return node_service_pb2.Empty()
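
The request-level gRPC prints moved behind DEBUG >= 5, so a run at the CI's DEBUG=7 still shows them while DEBUG=2 now reports only connection and topology activity. A small sketch, assuming (as the CI invocations suggest) that DEBUG is parsed from the environment when exo is imported:

import os

os.environ["DEBUG"] = "5"            # SendPrompt / SendTensor / CollectTopology tracing now needs >= 5
os.environ["DEBUG_DISCOVERY"] = "7"  # discovery keeps its own, separate verbosity knob

from exo import DEBUG  # assumption: read from the environment at import time
assert DEBUG >= 5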

+ 14 - 21
exo/orchestration/standard_node.py

@@ -3,6 +3,7 @@ import json
 import asyncio
 import asyncio
 import uuid
 import uuid
 import time
 import time
+import traceback
 from typing import List, Dict, Optional, Tuple, Union
 from typing import List, Dict, Optional, Tuple, Union
 from exo.networking import Discovery, PeerHandle, Server
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
 from exo.inference.inference_engine import InferenceEngine, Shard
@@ -13,6 +14,7 @@ from exo.topology.partitioning_strategy import Partition, PartitioningStrategy,
 from exo import DEBUG
 from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
 from exo.viz.topology_viz import TopologyViz
+from exo.download.hf.hf_helpers import RepoProgressEvent
 
 
 
 
 class StandardNode(Node):
 class StandardNode(Node):
@@ -23,7 +25,7 @@ class StandardNode(Node):
     inference_engine: InferenceEngine,
     inference_engine: InferenceEngine,
     discovery: Discovery,
     discovery: Discovery,
     partitioning_strategy: PartitioningStrategy = None,
     partitioning_strategy: PartitioningStrategy = None,
-    max_generate_tokens: int = 256,
+    max_generate_tokens: int = 1024,
     chatgpt_api_endpoint: Optional[str] = None,
     chatgpt_api_endpoint: Optional[str] = None,
     web_chat_url: Optional[str] = None,
     web_chat_url: Optional[str] = None,
     disable_tui: Optional[bool] = False,
     disable_tui: Optional[bool] = False,
@@ -42,6 +44,7 @@ class StandardNode(Node):
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
+    self.node_download_progress: Dict[str, RepoProgressEvent] = {}
 
 
   def on_node_status(self, request_id, opaque_status):
   def on_node_status(self, request_id, opaque_status):
     try:
     try:
@@ -54,13 +57,14 @@ class StandardNode(Node):
             self.current_topology.active_node_id = None
             self.current_topology.active_node_id = None
       download_progress = None
       download_progress = None
       if status_data.get("type", "") == "download_progress":
       if status_data.get("type", "") == "download_progress":
-        if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('current')}/{status_data.get('total')} ({round(status_data.get('current') / status_data.get('total') * 100, 2)}%)")
-        if status_data.get("node_id") == self.id:
-          download_progress = (status_data.get('current'), status_data.get('total'))
+        if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
+        download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
+        self.node_download_progress[status_data.get('node_id')] = download_progress
       if self.topology_viz:
       if self.topology_viz:
-        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), download_progress)
-    except json.JSONDecodeError:
-      pass
+        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id, self.node_download_progress)
+    except Exception as e:
+      if DEBUG >= 1: print(f"Error updating visualization: {e}")
+      if DEBUG >= 1: traceback.print_exc()
 
 
   async def start(self, wait_for_peers: int = 0) -> None:
   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
     await self.server.start()
@@ -231,8 +235,6 @@ class StandardNode(Node):
       return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
       return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
     except Exception as e:
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       print(f"Error processing tensor for shard {shard}: {e}")
-      import traceback
-
       traceback.print_exc()
       traceback.print_exc()
       return None
       return None
 
 
@@ -287,15 +289,14 @@ class StandardNode(Node):
 
 
   async def update_peers(self, wait_for_peers: int = 0) -> None:
   async def update_peers(self, wait_for_peers: int = 0) -> None:
     self.peers = await self.discovery.discover_peers(wait_for_peers)
     self.peers = await self.discovery.discover_peers(wait_for_peers)
-    if DEBUG >= 2: print(f"Starting with the following peers: {self.peers}")
-    if DEBUG >= 2: print("Connecting to new peers...")
     for peer in self.peers:
     for peer in self.peers:
       is_connected = await peer.is_connected()
       is_connected = await peer.is_connected()
       if DEBUG >= 2 and is_connected:
       if DEBUG >= 2 and is_connected:
         print(f"Already connected to {peer.id()}: {is_connected}")
         print(f"Already connected to {peer.id()}: {is_connected}")
       if not is_connected:
       if not is_connected:
+        if DEBUG >= 2: print(f"Connecting to {peer.id()}...")
         await peer.connect()
         await peer.connect()
-        if DEBUG >= 0: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})")
+        if DEBUG >= 1: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})")
 
 
   async def periodic_topology_collection(self, interval: int):
   async def periodic_topology_collection(self, interval: int):
     while True:
     while True:
@@ -306,9 +307,6 @@ class StandardNode(Node):
       except Exception as e:
       except Exception as e:
         print(f"Error collecting topology: {e}")
         print(f"Error collecting topology: {e}")
 
 
-      if DEBUG >= 2: print("Topology collection task executed.")
-      if DEBUG >= 2: print(f"Current topology: {self.topology}")
-
   async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
   async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
     if request_id not in self.buffered_token_output:
     if request_id not in self.buffered_token_output:
       return None, False
       return None, False
@@ -328,7 +326,6 @@ class StandardNode(Node):
       next_topology.add_edge(self.id, peer.id())
       next_topology.add_edge(self.id, peer.id())
 
 
       if peer.id() in prev_visited:
       if peer.id() in prev_visited:
-        if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
         continue
         continue
 
 
       if max_depth <= 0:
       if max_depth <= 0:
@@ -345,7 +342,7 @@ class StandardNode(Node):
     next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
     next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
     self.topology = next_topology
     self.topology = next_topology
     if self.topology_viz:
     if self.topology_viz:
-      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
+      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id)
     return next_topology
     return next_topology
 
 
   @property
   @property
@@ -368,8 +365,6 @@ class StandardNode(Node):
         print(f"Timeout broadcasting result to {peer.id()}")
         print(f"Timeout broadcasting result to {peer.id()}")
       except Exception as e:
       except Exception as e:
         print(f"Error broadcasting result to {peer.id()}: {e}")
         print(f"Error broadcasting result to {peer.id()}: {e}")
-        import traceback
-
         traceback.print_exc()
         traceback.print_exc()
 
 
     await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
     await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
@@ -383,8 +378,6 @@ class StandardNode(Node):
         print(f"Timeout sending opaque status to {peer.id()}")
         print(f"Timeout sending opaque status to {peer.id()}")
       except Exception as e:
       except Exception as e:
         print(f"Error sending opaque status to {peer.id()}: {e}")
         print(f"Error sending opaque status to {peer.id()}: {e}")
-        import traceback
-
         traceback.print_exc()
         traceback.print_exc()
 
 
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
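
Download progress is now tracked per node as a RepoProgressEvent and forwarded to the topology view keyed by node id. A hypothetical opaque-status payload of the shape on_node_status parses (the field names under "progress" mirror the RepoProgressEvent constructor used in exo/viz/test_topology_viz.py below; the exact wire schema expected by from_dict is an assumption):

status = {
  "type": "download_progress",
  "node_id": "node1",
  "progress": {
    "repo_id": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
    "repo_revision": "main",
    "completed_files": 5,
    "total_files": 10,
    "downloaded_bytes": 500_000_000,
    "downloaded_bytes_this_session": 250_000_000,
    "total_bytes": 1_000_000_000,
    "overall_speed": 5_000_000,
    "overall_eta": 100,        # the dataclass uses a timedelta; plain seconds shown here for brevity
    "status": "in_progress",
    "file_progress": {},
  },
}
# on_node_status turns status["progress"] into RepoProgressEvent.from_dict(...) and stores it in
# self.node_download_progress[status["node_id"]] before refreshing the topology visualization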

+ 67 - 1
exo/viz/test_topology_viz.py

@@ -1,9 +1,63 @@
 import asyncio
 import asyncio
 import unittest
 import unittest
+from datetime import timedelta
 from exo.viz.topology_viz import TopologyViz
 from exo.viz.topology_viz import TopologyViz
 from exo.topology.topology import Topology
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.partitioning_strategy import Partition
 from exo.topology.partitioning_strategy import Partition
+from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
+
+
+def create_hf_repo_progress_event(
+    completed_files: int = 5,
+    total_files: int = 10,
+    downloaded_bytes: int = 500000000,
+    downloaded_bytes_this_session: int = 250000000,
+    total_bytes: int = 1000000000,
+    overall_speed: int = 5000000,
+    overall_eta: timedelta = timedelta(seconds=100),
+    file_progress: dict = None,
+    status: str = "in_progress"
+) -> RepoProgressEvent:
+    if file_progress is None:
+        file_progress = {
+            "file1.bin": RepoFileProgressEvent(
+                repo_id="repo_id",
+                repo_revision="repo_revision",
+                file_path="file1.bin",
+                downloaded=100000000,
+                downloaded_this_session=50000000,
+                total=200000000,
+                speed=1000000,
+                eta=timedelta(seconds=100),
+                status="in_progress"
+            ),
+            "file2.bin": RepoFileProgressEvent(
+                repo_id="repo_id",
+                repo_revision="repo_revision",
+                file_path="file2.bin",
+                downloaded=200000000,
+                downloaded_this_session=100000000,
+                total=200000000,
+                speed=2000000,
+                eta=timedelta(seconds=0),
+                status="complete"
+            )
+        }
+
+    return RepoProgressEvent(
+        repo_id="repo_id",
+        repo_revision="repo_revision",
+        completed_files=completed_files,
+        total_files=total_files,
+        downloaded_bytes=downloaded_bytes,
+        downloaded_bytes_this_session=downloaded_bytes_this_session,
+        total_bytes=total_bytes,
+        overall_speed=overall_speed,
+        overall_eta=overall_eta,
+        file_progress=file_progress,
+        status=status
+    )
 
 
 
 
 class TestNodeViz(unittest.IsolatedAsyncioTestCase):
 class TestNodeViz(unittest.IsolatedAsyncioTestCase):
@@ -30,7 +84,7 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
     await asyncio.sleep(2)  # Simulate running for a short time
     await asyncio.sleep(2)  # Simulate running for a short time
 
 
   async def test_layout_generation(self):
   async def test_layout_generation(self):
-    self.top_viz._generate_layout()
+    # self.top_viz._generate_layout()
     self.top_viz.refresh()
     self.top_viz.refresh()
     import time
     import time
 
 
@@ -43,6 +97,13 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
         Partition("node2", 0.4, 0.8),
         Partition("node2", 0.4, 0.8),
         Partition("node3", 0.8, 0.9),
         Partition("node3", 0.8, 0.9),
       ],
       ],
+      "node1",
+      {
+        "node1": create_hf_repo_progress_event(),
+        "node2": create_hf_repo_progress_event(),
+        "node3": create_hf_repo_progress_event(),
+        "node4": create_hf_repo_progress_event(),
+      },
     )
     )
     time.sleep(2)
     time.sleep(2)
     self.topology.active_node_id = "node3"
     self.topology.active_node_id = "node3"
@@ -54,6 +115,11 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
         Partition("node2", 0.5, 0.7),
         Partition("node2", 0.5, 0.7),
         Partition("node4", 0.7, 0.9),
         Partition("node4", 0.7, 0.9),
       ],
       ],
+      "node5",
+      {
+        "node1": create_hf_repo_progress_event(),
+        "node5": create_hf_repo_progress_event(),
+      },
     )
     )
     time.sleep(2)
     time.sleep(2)
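For orientation, a minimal sketch (illustrative, not part of the diff) of how the helper's default byte counts render: the viz below draws a 30-cell bar using int(downloaded / total * 30), so 500 MB of a 1 GB total comes out as a half-full bar at 50%:

    downloaded_bytes, total_bytes = 500_000_000, 1_000_000_000
    filled = int(downloaded_bytes / total_bytes * 30)           # 15 of 30 cells
    bar = f"[{'=' * filled}{' ' * (30 - filled)}]"
    print(bar, f"{downloaded_bytes / total_bytes * 100:.1f}%")  # [===============               ] 50.0%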
 
 

+ 110 - 38
exo/viz/topology_viz.py

@@ -1,51 +1,72 @@
 import math
 import math
-from typing import List, Optional, Tuple
-from exo.helpers import exo_text
+from typing import List, Optional, Tuple, Dict
+from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.topology.topology import Topology
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
 from exo.topology.partitioning_strategy import Partition
+from exo.download.hf.hf_helpers import RepoProgressEvent
 from rich.console import Console
 from rich.console import Console
 from rich.panel import Panel
 from rich.panel import Panel
 from rich.text import Text
 from rich.text import Text
 from rich.live import Live
 from rich.live import Live
 from rich.style import Style
 from rich.style import Style
+from rich.table import Table
+from rich.layout import Layout
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
 
 
-
 class TopologyViz:
 class TopologyViz:
   def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
   def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
     self.chatgpt_api_endpoint = chatgpt_api_endpoint
     self.chatgpt_api_endpoint = chatgpt_api_endpoint
     self.web_chat_url = web_chat_url
     self.web_chat_url = web_chat_url
     self.topology = Topology()
     self.topology = Topology()
     self.partitions: List[Partition] = []
     self.partitions: List[Partition] = []
-    self.download_progress = None
+    self.node_id = None
+    self.node_download_progress: Dict[str, RepoProgressEvent] = {}
 
 
     self.console = Console()
     self.console = Console()
-    self.panel = Panel(self._generate_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
-    self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
+    self.layout = Layout()
+    self.layout.split(
+      Layout(name="main"),
+      Layout(name="download", size=15)
+    )
+    self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.download_panel = Panel("", title="Download Progress", border_style="cyan")
+    self.layout["main"].update(self.main_panel)
+    self.layout["download"].update(self.download_panel)
+    self.live_panel = Live(self.layout, auto_refresh=False, console=self.console)
     self.live_panel.start()
     self.live_panel.start()
 
 
-  def update_visualization(self, topology: Topology, partitions: List[Partition], download_progress: Optional[Tuple[int, int]] = None):
+  def update_visualization(self, topology: Topology, partitions: List[Partition], node_id: Optional[str] = None, node_download_progress: Optional[Dict[str, RepoProgressEvent]] = None):
     self.topology = topology
     self.topology = topology
     self.partitions = partitions
     self.partitions = partitions
-    self.download_progress = download_progress
+    self.node_id = node_id
+    if node_download_progress:
+      self.node_download_progress = node_download_progress
     self.refresh()
     self.refresh()
 
 
   def refresh(self):
   def refresh(self):
-    self.panel.renderable = self._generate_layout()
+    self.main_panel.renderable = self._generate_main_layout()
     # Update the panel title with the number of nodes and partitions
     # Update the panel title with the number of nodes and partitions
     node_count = len(self.topology.nodes)
     node_count = len(self.topology.nodes)
-    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''}){f' {self.download_progress[0]/self.download_progress[1]:.2%} Downloaded' if self.download_progress else ''}"
-    self.live_panel.update(self.panel, refresh=True)
+    self.main_panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
+
+    # Only show download_panel if there are in-progress downloads
+    if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
+      self.download_panel.renderable = self._generate_download_layout()
+      self.layout["download"].visible = True
+    else:
+      self.layout["download"].visible = False
+
+    self.live_panel.update(self.layout, refresh=True)
 
 
-  def _generate_layout(self) -> str:
+  def _generate_main_layout(self) -> str:
     # Calculate visualization parameters
     # Calculate visualization parameters
     num_partitions = len(self.partitions)
     num_partitions = len(self.partitions)
-    radius_x = 30  # Increased horizontal radius
-    radius_y = 12  # Decreased vertical radius
-    center_x, center_y = 50, 28  # Centered horizontally and moved up slightly
+    radius_x = 30
+    radius_y = 12
+    center_x, center_y = 50, 24  # Moved the center up to fit the shorter 48-row canvas
 
 
     # Generate visualization
     # Generate visualization
-    visualization = [[" " for _ in range(100)] for _ in range(55)]  # Decreased height
+    visualization = [[" " for _ in range(100)] for _ in range(48)]  # Reduced height from 55 to 48 rows
 
 
     # Add exo_text at the top in bright yellow
     # Add exo_text at the top in bright yellow
     exo_lines = exo_text.split("\n")
     exo_lines = exo_text.split("\n")
@@ -53,7 +74,7 @@ class TopologyViz:
     max_line_length = max(len(line) for line in exo_lines)
     max_line_length = max(len(line) for line in exo_lines)
     for i, line in enumerate(exo_lines):
     for i, line in enumerate(exo_lines):
       centered_line = line.center(max_line_length)
       centered_line = line.center(max_line_length)
-      start_x = (100 - max_line_length) // 2 + 15  # Center the text plus empirical adjustment of 15
+      start_x = (100 - max_line_length) // 2 + 15
       colored_line = Text(centered_line, style=yellow_style)
       colored_line = Text(centered_line, style=yellow_style)
       for j, char in enumerate(str(colored_line)):
       for j, char in enumerate(str(colored_line)):
         if 0 <= start_x + j < 100 and i < len(visualization):
         if 0 <= start_x + j < 100 and i < len(visualization):
@@ -68,9 +89,9 @@ class TopologyViz:
 
 
     info_start_y = len(exo_lines) + 1
     info_start_y = len(exo_lines) + 1
     for i, line in enumerate(info_lines):
     for i, line in enumerate(info_lines):
-      start_x = (100 - len(line)) // 2 + 15  # Center the info lines plus empirical adjustment of 15
+      start_x = (100 - len(line)) // 2 + 15
       for j, char in enumerate(line):
       for j, char in enumerate(line):
-        if 0 <= start_x + j < 100 and info_start_y + i < 55:
+        if 0 <= start_x + j < 100 and info_start_y + i < 48:
           visualization[info_start_y + i][start_x + j] = char
           visualization[info_start_y + i][start_x + j] = char
 
 
     # Calculate total FLOPS and position on the bar
     # Calculate total FLOPS and position on the bar
@@ -78,13 +99,13 @@ class TopologyViz:
     bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
     bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
 
 
     # Add GPU poor/rich bar
     # Add GPU poor/rich bar
-    bar_width = 30  # Increased bar width
-    bar_start_x = (100 - bar_width) // 2  # Center the bar
-    bar_y = info_start_y + len(info_lines) + 1  # Position the bar below the info section with two cells of space
+    bar_width = 30
+    bar_start_x = (100 - bar_width) // 2
+    bar_y = info_start_y + len(info_lines) + 1
 
 
     # Create a gradient bar using emojis
     # Create a gradient bar using emojis
     gradient_bar = Text()
     gradient_bar = Text()
-    emojis = ["🟥", "🟧", "🟨", "🟩"]  # Red, Orange, Yellow, Green
+    emojis = ["🟥", "🟧", "🟨", "🟩"]
     for i in range(bar_width):
     for i in range(bar_width):
       emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
       emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
       gradient_bar.append(emojis[emoji_index])
       gradient_bar.append(emojis[emoji_index])
@@ -106,6 +127,9 @@ class TopologyViz:
     visualization[bar_y + 1][pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
     visualization[bar_y + 1][pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
     visualization[bar_y + 2][pos_x] = "▲"
     visualization[bar_y + 2][pos_x] = "▲"
 
 
+    # Add extra vertical spacing below the bar before drawing the node ring
+    bar_y += 4
+
     for i, partition in enumerate(self.partitions):
     for i, partition in enumerate(self.partitions):
       device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
       device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
 
 
@@ -113,11 +137,13 @@ class TopologyViz:
       x = int(center_x + radius_x * math.cos(angle))
       x = int(center_x + radius_x * math.cos(angle))
       y = int(center_y + radius_y * math.sin(angle))
       y = int(center_y + radius_y * math.sin(angle))
 
 
-      # Place node with different color for active node
+      # Place node with different color for active node and this node
       if partition.node_id == self.topology.active_node_id:
       if partition.node_id == self.topology.active_node_id:
-        visualization[y][x] = "🔴"  # Red circle for active node
+        visualization[y][x] = "🔴"
+      elif partition.node_id == self.node_id:
+        visualization[y][x] = "🟢"
       else:
       else:
-        visualization[y][x] = "🔵"  # Blue circle for inactive nodes
+        visualization[y][x] = "🔵"
 
 
       # Place node info (model, memory, TFLOPS, partition) on three lines
       # Place node info (model, memory, TFLOPS, partition) on three lines
       node_info = [
       node_info = [
@@ -127,28 +153,27 @@ class TopologyViz:
       ]
       ]
 
 
       # Calculate info position based on angle
       # Calculate info position based on angle
-      info_distance_x = radius_x + 6  # Increased horizontal distance
-      info_distance_y = radius_y + 3  # Decreased vertical distance
+      info_distance_x = radius_x + 6
+      info_distance_y = radius_y + 3
       info_x = int(center_x + info_distance_x * math.cos(angle))
       info_x = int(center_x + info_distance_x * math.cos(angle))
       info_y = int(center_y + info_distance_y * math.sin(angle))
       info_y = int(center_y + info_distance_y * math.sin(angle))
 
 
       # Adjust text position to avoid overwriting the node icon and prevent cutoff
       # Adjust text position to avoid overwriting the node icon and prevent cutoff
-      if info_x < x:  # Text is to the left of the node
+      if info_x < x:
         info_x = max(0, x - len(max(node_info, key=len)) - 1)
         info_x = max(0, x - len(max(node_info, key=len)) - 1)
-      elif info_x > x:  # Text is to the right of the node
+      elif info_x > x:
         info_x = min(99 - len(max(node_info, key=len)), info_x)
         info_x = min(99 - len(max(node_info, key=len)), info_x)
 
 
       # Adjust for top and bottom nodes
       # Adjust for top and bottom nodes
-      if 5 * math.pi / 4 < angle < 7 * math.pi / 4:  # Node is near the top
-        info_x += 4  # Shift text slightly to the right
-      elif math.pi / 4 < angle < 3 * math.pi / 4:  # Node is near the bottom
-        info_x += 3  # Shift text slightly to the right
-        info_y -= 2  # Move text up by two cells
+      if 5 * math.pi / 4 < angle < 7 * math.pi / 4:
+        info_x += 4
+      elif math.pi / 4 < angle < 3 * math.pi / 4:
+        info_x += 3
+        info_y -= 2
 
 
       for j, line in enumerate(node_info):
       for j, line in enumerate(node_info):
         for k, char in enumerate(line):
         for k, char in enumerate(line):
-          if 0 <= info_y + j < 55 and 0 <= info_x + k < 100:  # Updated height check
-            # Ensure we're not overwriting the node icon
+          if 0 <= info_y + j < 48 and 0 <= info_x + k < 100:
             if info_y + j != y or info_x + k != x:
             if info_y + j != y or info_x + k != x:
               visualization[info_y + j][info_x + k] = char
               visualization[info_y + j][info_x + k] = char
 
 
@@ -163,8 +188,55 @@ class TopologyViz:
       for step in range(1, steps):
       for step in range(1, steps):
         line_x = int(x + (next_x - x) * step / steps)
         line_x = int(x + (next_x - x) * step / steps)
         line_y = int(y + (next_y - y) * step / steps)
         line_y = int(y + (next_y - y) * step / steps)
-        if 0 <= line_y < 55 and 0 <= line_x < 100:  # Updated height check
+        if 0 <= line_y < 48 and 0 <= line_x < 100:
           visualization[line_y][line_x] = "-"
           visualization[line_y][line_x] = "-"
 
 
     # Convert to string
     # Convert to string
     return "\n".join("".join(str(char) for char in row) for row in visualization)
     return "\n".join("".join(str(char) for char in row) for row in visualization)
+
+  def _generate_download_layout(self) -> Table:
+    summary = Table(show_header=False, box=None, padding=(0, 1), expand=True)
+    summary.add_column("Info", style="cyan", no_wrap=True, ratio=50)
+    summary.add_column("Progress", style="cyan", no_wrap=True, ratio=40)
+    summary.add_column("Percentage", style="cyan", no_wrap=True, ratio=10)
+
+    # Current node download progress
+    if self.node_id in self.node_download_progress:
+      download_progress = self.node_download_progress[self.node_id]
+      title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):"
+      summary.add_row(Text(title, style="bold"))
+      progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
+      summary.add_row(progress_info)
+
+      eta_info = f"{download_progress.overall_eta}"
+      summary.add_row(eta_info)
+
+      summary.add_row("")  # Empty row for spacing
+
+      for file_path, file_progress in download_progress.file_progress.items():
+        if file_progress.status != "complete":
+          progress = int(file_progress.downloaded / file_progress.total * 30)
+          bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
+          percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
+          summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
+
+    summary.add_row("")  # Empty row for spacing
+
+    # Other nodes download progress summary
+    summary.add_row(Text("Other Nodes Download Progress:", style="bold"))
+    for node_id, progress in self.node_download_progress.items():
+      if node_id != self.node_id:
+        device = self.topology.nodes.get(node_id)
+        partition = next((p for p in self.partitions if p.node_id == node_id), None)
+        partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
+        percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
+        speed = pretty_print_bytes_per_second(progress.overall_speed)
+        device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
+        progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
+        progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]"
+        percentage_str = f"{percentage:.1f}%"
+        eta_str = f"{progress.overall_eta}"
+        summary.add_row(device_info, progress_info, percentage_str)
+        summary.add_row("", progress_bar, eta_str)
+
+    return summary
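A minimal usage sketch of the new interface (illustrative, not part of the diff): the endpoint URL and node id are hypothetical, and topology, partitions and progress stand for an already-built Topology, a List[Partition] and a RepoProgressEvent. The download panel only appears while at least one event still has status "in_progress", and this node is drawn as the green circle:

    viz = TopologyViz(chatgpt_api_endpoint="http://localhost:8000/v1/chat/completions")
    viz.update_visualization(
      topology,                                    # exo.topology.topology.Topology
      partitions,                                  # List[Partition]
      node_id="node1",                             # this node (green circle)
      node_download_progress={"node1": progress},  # Dict[str, RepoProgressEvent]
    )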

+ 53 - 0
extra/download_hf.py

@@ -0,0 +1,53 @@
+import argparse
+import asyncio
+from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent
+
+DEFAULT_ALLOW_PATTERNS = [
+    "*.json",
+    "*.py",
+    "tokenizer.model",
+    "*.tiktoken",
+    "*.txt",
+    "*.safetensors",
+]
+# Always ignore the `.git` and `.cache/huggingface` folders when downloading
+DEFAULT_IGNORE_PATTERNS = [
+    ".git",
+    ".git/*",
+    "*/.git",
+    "**/.git/**",
+    ".cache/huggingface",
+    ".cache/huggingface/*",
+    "*/.cache/huggingface",
+    "**/.cache/huggingface/**",
+]
+
+async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
+    async def progress_callback(event: RepoProgressEvent):
+        print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
+        print(f"Estimated time remaining: {event.overall_eta}")
+        print("File Progress:")
+        for file_path, progress in event.file_progress.items():
+            status_icon = {
+                'not_started': '⚪',
+                'in_progress': '🔵',
+                'complete': '✅'
+            }[progress.status]
+            eta_str = str(progress.eta)
+            print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
+                  f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
+        print("\n")
+
+    await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
+    parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
+    parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
+    parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
+    parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
+
+    args = parser.parse_args()
+
+    asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
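Besides the CLI shown above (python extra/download_hf.py --repo-id meta-llama/Meta-Llama-3.1-8B-Instruct --revision main), the same download can be driven programmatically. A sketch, assuming exo is importable; the repo id is the one from the script's own help text and the patterns are illustrative:

    import asyncio
    from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent

    async def report(event: RepoProgressEvent):
      # Minimal progress callback: print overall byte counts only.
      print(f"{event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")

    asyncio.run(download_all_files(
      "meta-llama/Meta-Llama-3.1-8B-Instruct",  # repo_id
      "main",                                   # revision
      report,                                   # progress_callback (async)
      ["*.json", "*.safetensors"],              # allow_patterns
      None,                                     # ignore_patterns
    ))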

+ 30 - 5
main.py

@@ -2,13 +2,17 @@ import argparse
 import asyncio
 import asyncio
 import signal
 import signal
 import json
 import json
-import uuid
+import time
+import traceback
 from exo.orchestration.standard_node import StandardNode
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
 from exo.api import ChatGPTAPI
+from exo.download.shard_download import ShardDownloader, RepoProgressEvent
+from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id
+from exo.inference.shard import Shard
 
 
 # parse args
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -22,7 +26,7 @@ parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
-parser.add_argument("--max-generate-tokens", type=int, default=256, help="Max tokens to generate in each request")
+parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 args = parser.parse_args()
 args = parser.parse_args()
@@ -32,9 +36,10 @@ print_yellow_exo()
 system_info = get_system_info()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 print(f"Detected system: {system_info}")
 
 
+shard_downloader: ShardDownloader = HFShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
-inference_engine = get_inference_engine(inference_engine_name)
-print(f"Using inference engine: {inference_engine.__class__.__name__}")
+inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
+print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 
 
 if args.node_port is None:
 if args.node_port is None:
     args.node_port = find_available_port(args.node_host)
     args.node_port = find_available_port(args.node_host)
@@ -57,10 +62,30 @@ server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 node.server = server
 api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
 api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
 node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
 node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+def preemptively_start_download(request_id: str, opaque_status: str):
+    try:
+        status = json.loads(opaque_status)
+        if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
+            current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
+            if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
+            asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+    except Exception as e:
+        if DEBUG >= 2:
+            print(f"Failed to preemptively start download: {e}")
+            traceback.print_exc()
+node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 if args.prometheus_client_port:
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)
     start_metrics_server(node, args.prometheus_client_port)
-inference_engine.set_on_download_progress(lambda current, total: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": current, "total": total}))))
+
+last_broadcast_time = 0
+def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
+    global last_broadcast_time
+    current_time = time.time()
+    if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
+        last_broadcast_time = current_time
+        asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 
 async def shutdown(signal, loop):
 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""
     """Gracefully shutdown the server and close the asyncio loop."""

+ 3 - 1
setup.py

@@ -6,6 +6,7 @@ from setuptools import find_packages, setup
 install_requires = [
 install_requires = [
     "aiohttp==3.9.5",
     "aiohttp==3.9.5",
     "aiohttp_cors==0.7.0",
     "aiohttp_cors==0.7.0",
+    "aiofiles==24.1.0",
     "blobfile==2.1.1",
     "blobfile==2.1.1",
     "grpcio==1.64.1",
     "grpcio==1.64.1",
     "grpcio-tools==1.64.1",
     "grpcio-tools==1.64.1",
@@ -21,6 +22,7 @@ install_requires = [
     "requests==2.32.3",
     "requests==2.32.3",
     "rich==13.7.1",
     "rich==13.7.1",
     "safetensors==0.4.3",
     "safetensors==0.4.3",
+    "tenacity==9.0.0",
     "tiktoken==0.7.0",
     "tiktoken==0.7.0",
     "tokenizers==0.19.1",
     "tokenizers==0.19.1",
     "tqdm==4.66.4",
     "tqdm==4.66.4",
@@ -33,7 +35,7 @@ install_requires = [
 if sys.platform.startswith("darwin"):
 if sys.platform.startswith("darwin"):
     install_requires.extend(
     install_requires.extend(
         [
         [
-            "mlx==0.16.0",
+            "mlx==0.16.1",
             "mlx-lm==0.16.1",
             "mlx-lm==0.16.1",
         ]
         ]
     )
     )