Browse Source

separate hf_helpers, make extra dir with download_hf script, unify downloading so tinygrad uses the same method as mlx and interoperable model formats

Alex Cheema 9 months ago
parent
commit
545a486ed3

+ 2 - 19
exo/api/chatgpt_api.py

@@ -25,11 +25,11 @@ shard_mappings = {
   },
   "llama-3-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
   },
   ### mistral
   "mistral-nemo": {
@@ -76,14 +76,6 @@ class ChatCompletionRequest:
         }
 
 
-def resolve_tinygrad_tokenizer(model_id: str):
-  if model_id == "llama3-8b-sfr":
-    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
-  elif model_id == "llama3-70b-sfr":
-    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
-  else:
-    raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")
-
 
 async def resolve_tokenizer(model_id: str):
   try:
@@ -111,15 +103,6 @@ async def resolve_tokenizer(model_id: str):
 
     if DEBUG >= 2: print(traceback.format_exc())
 
-  try:
-    if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
-    return resolve_tinygrad_tokenizer(model_id)
-  except Exception as e:
-    if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer. Error: {e}")
-    import traceback
-
-    if DEBUG >= 2: print(traceback.format_exc())
-
   if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
   from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
 

+ 95 - 114
hf_async.py → exo/inference/hf_helpers.py

@@ -1,36 +1,17 @@
 import asyncio
 import aiohttp
 import os
-import argparse
 from urllib.parse import urljoin
-from typing import Callable, Optional, Coroutine, Any
+from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from datetime import datetime, timedelta
 from fnmatch import fnmatch
 from pathlib import Path
-from typing import Generator, Iterable, List, TypeVar, Union
+from typing import Generator, Iterable, TypeVar, TypedDict
+from dataclasses import dataclass
+from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
+from exo.helpers import DEBUG
 
 T = TypeVar("T")
-
-DEFAULT_ALLOW_PATTERNS = [
-    "*.json",
-    "*.py",
-    "tokenizer.model",
-    "*.tiktoken",
-    "*.txt",
-    "*.safetensors",
-]
-# Always ignore `.git` and `.cache/huggingface` folders in commits
-DEFAULT_IGNORE_PATTERNS = [
-    ".git",
-    ".git/*",
-    "*/.git",
-    "**/.git/**",
-    ".cache/huggingface",
-    ".cache/huggingface/*",
-    "*/.cache/huggingface",
-    "**/.cache/huggingface/**",
-]
-
 def filter_repo_objects(
     items: Iterable[T],
     *,
@@ -117,7 +98,35 @@ async def fetch_file_list(session, repo_id, revision, path=""):
         else:
             raise Exception(f"Failed to fetch file list: {response.status}")
 
-async def download_file(session, repo_id, revision, file_path, save_directory, progress_callback: Optional[Callable[[str, int, int, float, timedelta], Coroutine[Any, Any, None]]] = None):
+
+@dataclass
+class HFRepoFileProgressEvent:
+    file_path: str
+    downloaded: int
+    total: int
+    speed: float
+    eta: timedelta
+    status: Literal["not_started", "in_progress", "complete"]
+
+@dataclass
+class HFRepoProgressEvent:
+    completed_files: int
+    total_files: int
+    downloaded_bytes: int
+    total_bytes: int
+    overall_eta: timedelta
+    file_progress: Dict[str, HFRepoFileProgressEvent]
+
+HFRepoFileProgressCallback = Callable[[HFRepoFileProgressEvent], Coroutine[Any, Any, None]]
+HFRepoProgressCallback = Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]
+
+@retry(
+    stop=stop_after_attempt(5),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
+    reraise=True
+)
+async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[HFRepoFileProgressCallback] = None, use_range_request: bool = True):
     base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
     url = urljoin(base_url, file_path)
     local_path = os.path.join(save_directory, file_path)
@@ -125,64 +134,72 @@ async def download_file(session, repo_id, revision, file_path, save_directory, p
     os.makedirs(os.path.dirname(local_path), exist_ok=True)
 
     # Check if file already exists and get its size
-    if os.path.exists(local_path):
-        local_file_size = os.path.getsize(local_path)
-    else:
-        local_file_size = 0
+    local_file_size = os.path.getsize(local_path) if os.path.exists(local_path) else 0
+
+    headers = get_auth_headers()
+    if use_range_request:
+        headers["Range"] = f"bytes={local_file_size}-"
 
-    headers = {"Range": f"bytes={local_file_size}-"}
-    headers.update(get_auth_headers())
     async with session.get(url, headers=headers) as response:
+        total_size = int(response.headers.get('Content-Length', 0))
+        downloaded_size = local_file_size
+        mode = 'ab' if use_range_request else 'wb'
+        if downloaded_size == total_size:
+            if DEBUG >= 2: print(f"File already downloaded: {file_path}")
+            if progress_callback:
+                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
+            return
+
         if response.status == 200:
-            # File doesn't support range requests, start from beginning
+            # File doesn't support range requests or we're not using them, start from beginning
             mode = 'wb'
-            total_size = int(response.headers.get('Content-Length', 0))
             downloaded_size = 0
         elif response.status == 206:
             # Partial content, resume download
-            mode = 'ab'
-            content_range = response.headers.get('Content-Range')
-            total_size = int(content_range.split('/')[-1])
-            downloaded_size = local_file_size
+            content_range = response.headers.get('Content-Range', '')
+            try:
+                total_size = int(content_range.split('/')[-1])
+            except ValueError:
+                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
         elif response.status == 416:
             # Range not satisfiable, get the actual file size
-            if response.headers.get('Content-Type', '').startswith('text/html'):
-                content = await response.text()
-                print(f"Response content (HTML):\n{content}")
-            else:
-                print(response)
-            print("Return header: ", response.headers)
-            print("Return header: ", response.headers.get('Content-Range').split('/')[-1])
-            total_size = int(response.headers.get('Content-Range', '').split('/')[-1])
-            if local_file_size == total_size:
-                print(f"File already fully downloaded: {file_path}")
-                return
-            else:
-                # Start the download from the beginning
-                mode = 'wb'
-                downloaded_size = 0
+            content_range = response.headers.get('Content-Range', '')
+            try:
+                total_size = int(content_range.split('/')[-1])
+                if downloaded_size == total_size:
+                    if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
+                    if progress_callback:
+                        await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
+                    return
+            except ValueError:
+                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
         else:
-            print(f"Failed to download {file_path}: {response.status}")
-            return
+            raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
 
         if downloaded_size == total_size:
             print(f"File already downloaded: {file_path}")
+            if progress_callback:
+                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
             return
 
+        DOWNLOAD_CHUNK_SIZE = 32768
         start_time = datetime.now()
-        new_downloaded_size = 0
         with open(local_path, mode) as f:
-            async for chunk in response.content.iter_chunked(8192):
+            async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
                 f.write(chunk)
-                new_downloaded_size += len(chunk)
-                if progress_callback:
+                downloaded_size += len(chunk)
+                if progress_callback and total_size:
                     elapsed_time = (datetime.now() - start_time).total_seconds()
-                    speed = new_downloaded_size / elapsed_time if elapsed_time > 0 else 0
-                    eta = timedelta(seconds=(total_size - downloaded_size - new_downloaded_size) / speed) if speed > 0 else timedelta(0)
-                    await progress_callback(file_path, new_downloaded_size, total_size - downloaded_size, speed, eta)
-        print(f"Downloaded: {file_path}")
-
-async def download_all_files(repo_id, revision="main", progress_callback: Optional[Callable[[int, int, int, int, timedelta, dict], Coroutine[Any, Any, None]]] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
+                    speed = downloaded_size / elapsed_time if elapsed_time > 0 else 0
+                    remaining_size = total_size - downloaded_size
+                    eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
+                    status = "in_progress" if downloaded_size < total_size else "complete"
+                    await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, speed, eta, status))
+        if DEBUG >= 2: print(f"Downloaded: {file_path}")
+
+async def download_all_files(repo_id: str, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
     repo_root = get_repo_root(repo_id)
     refs_dir = repo_root / "refs"
     snapshots_dir = repo_root / "snapshots"
@@ -197,7 +214,7 @@ async def download_all_files(repo_id, revision="main", progress_callback: Option
         headers = get_auth_headers()
         async with session.get(api_url, headers=headers) as response:
             if response.status != 200:
-                raise Exception(f"Failed to fetch revision info: {response.status}")
+                raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
             revision_info = await response.json()
             commit_hash = revision_info['sha']
 
@@ -215,68 +232,32 @@ async def download_all_files(repo_id, revision="main", progress_callback: Option
         completed_files = 0
         total_bytes = sum(file["size"] for file in filtered_file_list)
         downloaded_bytes = 0
-        new_downloaded_bytes = 0
-        file_progress = {file["path"]: {"status": "not_started", "downloaded": 0, "total": file["size"]} for file in filtered_file_list}
+        file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
         start_time = datetime.now()
 
         async def download_with_progress(file_info):
-            nonlocal completed_files, downloaded_bytes, new_downloaded_bytes, file_progress
-
-            async def file_progress_callback(path, file_downloaded, file_total, speed, file_eta):
-                nonlocal downloaded_bytes, new_downloaded_bytes, file_progress
-                new_downloaded_bytes += file_downloaded - file_progress[path]['downloaded']
-                downloaded_bytes += file_downloaded - file_progress[path]['downloaded']
-                file_progress[path].update({
-                    'status': 'in_progress',
-                    'downloaded': file_downloaded,
-                    'total': file_total,
-                    'speed': speed,
-                    'eta': file_eta
-                })
+            nonlocal completed_files, downloaded_bytes, file_progress
+
+            async def file_progress_callback(event: HFRepoFileProgressEvent):
+                nonlocal downloaded_bytes, file_progress
+                downloaded_bytes += event.downloaded - file_progress[event.file_path].downloaded
+                file_progress[event.file_path] = event
                 if progress_callback:
                     elapsed_time = (datetime.now() - start_time).total_seconds()
-                    overall_speed = new_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
+                    overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
                     overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
-                    await progress_callback(completed_files, total_files, new_downloaded_bytes, total_bytes, overall_eta, file_progress)
+                    await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))
 
             await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
             completed_files += 1
-            file_progress[file_info["path"]]['status'] = 'complete'
+            file_progress[file_info["path"]] = HFRepoFileProgressEvent(file_info["path"], file_info["size"], file_info["size"], 0, timedelta(0), "complete")
             if progress_callback:
                 elapsed_time = (datetime.now() - start_time).total_seconds()
-                overall_speed = new_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
+                overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
                 overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
-                await progress_callback(completed_files, total_files, new_downloaded_bytes, total_bytes, overall_eta, file_progress)
+                await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))
 
         tasks = [download_with_progress(file_info) for file_info in filtered_file_list]
         await asyncio.gather(*tasks)
 
-async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
-    async def progress_callback(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress):
-        print(f"Overall Progress: {completed_files}/{total_files} files, {downloaded_bytes}/{total_bytes} bytes")
-        print(f"Estimated time remaining: {overall_eta}")
-        print("File Progress:")
-        for file_path, progress in file_progress.items():
-            status_icon = {
-                'not_started': '⚪',
-                'in_progress': '🔵',
-                'complete': '✅'
-            }[progress['status']]
-            eta_str = str(progress.get('eta', 'N/A'))
-            print(f"{status_icon} {file_path}: {progress.get('downloaded', 0)}/{progress['total']} bytes, "
-                  f"Speed: {progress.get('speed', 0):.2f} B/s, ETA: {eta_str}")
-        print("\n")
-
-    await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
-    parser.add_argument("--repo-id", help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
-    parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
-    parser.add_argument("--allow-patterns", nargs="*", default=DEFAULT_ALLOW_PATTERNS, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
-    parser.add_argument("--ignore-patterns", nargs="*", default=DEFAULT_IGNORE_PATTERNS, help="Patterns of files to ignore (e.g., '.*')")
-
-    args = parser.parse_args()
-
-    asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
+    return snapshot_dir

+ 3 - 3
exo/inference/inference_engine.py

@@ -1,9 +1,9 @@
 import numpy as np
 
-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional, Callable, Coroutine, Any
 from abc import ABC, abstractmethod
 from .shard import Shard
-
+from exo.inference.hf_helpers import HFRepoProgressEvent
 
 class InferenceEngine(ABC):
   @abstractmethod
@@ -15,5 +15,5 @@ class InferenceEngine(ABC):
     pass
 
   @abstractmethod
-  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
+  def set_progress_callback(self, progress_callback: Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]):
     pass

+ 6 - 5
exo/inference/mlx/sharded_inference_engine.py

@@ -5,12 +5,13 @@ from .sharded_model import StatefulShardedModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
 from typing import Optional, Callable
+from exo.inference.hf_helpers import HFRepoProgressCallback
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self, on_download_progress: Callable[[int, int], None] = None):
+  def __init__(self, progress_callback: Optional[HFRepoProgressCallback] = None):
     self.shard = None
-    self.on_download_progress = on_download_progress
+    self.progress_callback = progress_callback
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
@@ -33,9 +34,9 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return
 
-    model_shard, self.tokenizer = await load_shard(shard.model_id, shard, on_download_progress=self.on_download_progress)
+    model_shard, self.tokenizer = await load_shard(shard.model_id, shard, progress_callback=self.progress_callback)
     self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
     self.shard = shard
 
-  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
-    self.on_download_progress = on_download_progress
+  def set_progress_callback(self, progress_callback: HFRepoProgressCallback):
+    self.progress_callback = progress_callback

+ 7 - 54
exo/inference/mlx/sharded_utils.py

@@ -12,10 +12,7 @@ from typing import Optional, Tuple, Union, List, Callable
 from PIL import Image
 from io import BytesIO
 import base64
-import os
-import concurrent.futures
 
-from exo import DEBUG
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download, list_repo_tree, get_paths_info
@@ -28,6 +25,8 @@ from transformers import AutoProcessor
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tuner.utils import apply_lora_layers
 
+from exo import DEBUG
+from exo.inference.hf_helpers import download_all_files, HFRepoProgressCallback
 from ..shard import Shard
 
 
@@ -164,52 +163,6 @@ def load_model_shard(
   return model
 
 
-async def get_repo_size(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
-  it = await asyncio.to_thread(list_repo_tree, repo_id, revision=revision, repo_type=repo_type)
-  files = list(filter_repo_objects(it, allow_patterns=allow_patterns, key=lambda f: f.path))
-  return sum(file.size for file in files if hasattr(file, "size") and file.size is not None)
-
-async def monitor_progress(dir, total_size, print_progress=False, on_progress: Callable[[int, int], None] = None):
-    while True:
-      try:
-        await asyncio.sleep(0.1)
-        current_size = sum(os.path.getsize(os.path.join(root, file))
-                            for root, _, files in os.walk(dir)
-                            for file in files)
-        progress = min(current_size / total_size * 100, 100)
-        if print_progress:
-          print(f"\rProgress: {progress:.2f}% ({current_size}/{total_size} bytes)", end="", flush=True)
-        if on_progress:
-          on_progress(current_size, total_size)
-        if progress >= 100:
-          if print_progress:
-            print("\nDownload complete!")
-          break
-      except Exception as e:
-        print(f"Error monitoring progress: {e}")
-
-async def download_repo(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
-    with concurrent.futures.ThreadPoolExecutor() as pool:
-        return await asyncio.get_event_loop().run_in_executor(
-            pool,
-            partial(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
-        )
-
-async def download_async_with_progress(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, on_progress: Callable[[int, int], None] = None):
-  storage_folder = os.path.join(HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type="model"))
-  # os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
-  # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
-
-  total_size = await get_repo_size(repo_id)
-
-  # Create tasks for download and progress checking
-  download_task = asyncio.create_task(download_repo(repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type))
-  progress_task = asyncio.create_task(monitor_progress(storage_folder, total_size, on_progress=on_progress))
-
-  # Wait for both tasks to complete
-  result = await asyncio.gather(download_task, progress_task, return_exceptions=True)
-  return result[0]  # Return the result from download_task
-
 repo_id_safetensors_layers = {
   "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit": {
     "model.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
@@ -313,7 +266,7 @@ def get_safetensors_allow_patterns(repo_id: str, shard: Optional[Shard] = None):
 
     return allow_patterns if len(allow_patterns) > 0 else ["*.safetensors"]
 
-async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
+async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None) -> Path:
   """
   Ensures the model is available locally. If the path does not exist locally,
   it is downloaded from the Hugging Face Hub.
@@ -329,7 +282,7 @@ async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, re
   if not model_path.exists():
     try:
       model_path = Path(
-        await download_async_with_progress(
+        await download_all_files(
           repo_id=path_or_hf_repo,
           revision=revision,
           allow_patterns=[
@@ -339,7 +292,7 @@ async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, re
             "*.tiktoken",
             "*.txt",
           ] + get_safetensors_allow_patterns(path_or_hf_repo, shard),
-          on_progress=on_download_progress,
+          progress_callback=progress_callback,
         )
       )
     except RepositoryNotFoundError:
@@ -360,7 +313,7 @@ async def load_shard(
   model_config={},
   adapter_path: Optional[str] = None,
   lazy: bool = False,
-  on_download_progress: Callable[[int, int], None] = None,
+  progress_callback: Optional[HFRepoProgressCallback] = None,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
   """
   Load the model and tokenizer from a given path or a huggingface repository.
@@ -383,7 +336,7 @@ async def load_shard(
    FileNotFoundError: If config file or safetensors are not found.
    ValueError: If model class or args class are not found.
   """
-  model_path = await get_model_path(path_or_hf_repo, shard, on_download_progress=on_download_progress)
+  model_path = await get_model_path(path_or_hf_repo, shard, progress_callback=progress_callback)
 
   model = load_model_shard(model_path, shard, lazy, model_config)
   if adapter_path is not None:

+ 13 - 12
exo/inference/test_inference_engine.py

@@ -1,3 +1,4 @@
+from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
@@ -40,17 +41,17 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(next_resp_full, resp4)
 
 
-asyncio.run(
-  test_inference_engine(
-    MLXDynamicShardInferenceEngine(),
-    MLXDynamicShardInferenceEngine(),
-    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-  )
-)
+# asyncio.run(
+#   test_inference_engine(
+#     MLXDynamicShardInferenceEngine(),
+#     MLXDynamicShardInferenceEngine(),
+#     "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+#   )
+# )
 
 # TODO: Need more memory or a smaller model
-# asyncio.run(test_inference_engine(
-#     TinygradDynamicShardInferenceEngine(),
-#     TinygradDynamicShardInferenceEngine(),
-#     "llama3-8b-sfr",
-# ))
+asyncio.run(test_inference_engine(
+    TinygradDynamicShardInferenceEngine(),
+    TinygradDynamicShardInferenceEngine(),
+    "llama3-8b-sfr",
+))

+ 9 - 72
exo/inference/tinygrad/inference.py

@@ -1,9 +1,7 @@
-import asyncio
 from functools import partial
 from pathlib import Path
-from typing import List, Optional, Union, Callable
+from typing import List, Optional, Union, Callable, Coroutine, Any
 import json
-import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
@@ -12,7 +10,7 @@ from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
 import numpy as np
-import os
+from exo.inference.hf_helpers import HFRepoProgressCallback, HFRepoProgressEvent, download_all_files, get_repo_root
 
 MODEL_PARAMS = {
   "8B": {
@@ -46,15 +44,6 @@ MODEL_PARAMS = {
 
 
 # **** helper functions ****
-async def fetch_async(
-  url: str,
-  name: Optional[Union[Path, str]] = None,
-  subdir: Optional[str] = None,
-  allow_caching=not os.getenv("DISABLE_HTTP_CACHE"),
-) -> Path:
-  func = partial(fetch, url, name, subdir, allow_caching)
-  return await asyncio.get_event_loop().run_in_executor(None, func)
-
 
 def concat_weights(models, device=None):
   def convert(name) -> Tensor:
@@ -159,8 +148,9 @@ def prefill(model, toks, start_pos=0):
 
 
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self):
+  def __init__(self, progress_callback: Optional[HFRepoProgressCallback] = None):
     self.shard = None
+    self.progress_callback = progress_callback
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
@@ -199,62 +189,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return
 
-    model_path = Path(shard.model_id)
-    models_dir = Path(_cache_dir) / "tinygrad" / "downloads"
-    model_path = models_dir / shard.model_id
-    size = "8B"
-    if Path(model_path / "tokenizer_config.json").exists():
-      model = model_path
-    else:
-
-      if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
-      if shard.model_id.lower().find("llama3-8b-sfr") != -1:
-        num_files = 4
-        for i in range(num_files):
-          await fetch_async(
-            f"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/model-{(i+1):05d}-of-{num_files:05d}.safetensors",
-            f"model-{(i+1):05d}-of-{num_files:05d}.safetensors",
-            subdir=shard.model_id,
-          )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/config.json",
-          "config.json",
-          subdir=shard.model_id,
-        )
-        model = await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/raw/main/model.safetensors.index.json",
-          "model.safetensors.index.json",
-          subdir=shard.model_id,
-        )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/special_tokens_map.json",
-          "special_tokens_map.json",
-          subdir=shard.model_id,
-        )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json",
-          "tokenizer.json",
-          subdir=shard.model_id,
-        )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer_config.json",
-          "tokenizer_config.json",
-          subdir=shard.model_id,
-        )
-        size = "8B"
-      elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
-        raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
-        # fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-70B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
-        # model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
-        # size = "70B"
-      else:
-        raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}")
-
-    model = build_transformer(model_path, shard=shard, model_size=size)
+    model_path = await download_all_files(shard.model_id, progress_callback=self.progress_callback)
+    print(f"{model_path=}")
+    model = build_transformer(model_path, shard=shard, model_size="8B" if "8b" in shard.model_id else "70B" if "70b" in shard.model_id else "8B")
     from transformers import AutoTokenizer
     tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
 
@@ -262,5 +199,5 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.model = model
     self.tokenizer = tokenizer
 
-  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
-    pass
+  def set_progress_callback(self, progress_callback: Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]):
+    self.progress_callback = progress_callback

+ 53 - 0
extra/download_hf.py

@@ -0,0 +1,53 @@
+import argparse
+import asyncio
+from exo.inference.hf_helpers import download_all_files, HFRepoProgressEvent, HFRepoFileProgressEvent
+
+DEFAULT_ALLOW_PATTERNS = [
+    "*.json",
+    "*.py",
+    "tokenizer.model",
+    "*.tiktoken",
+    "*.txt",
+    "*.safetensors",
+]
+# Always ignore `.git` and `.cache/huggingface` folders in commits
+DEFAULT_IGNORE_PATTERNS = [
+    ".git",
+    ".git/*",
+    "*/.git",
+    "**/.git/**",
+    ".cache/huggingface",
+    ".cache/huggingface/*",
+    "*/.cache/huggingface",
+    "**/.cache/huggingface/**",
+]
+
+async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
+    async def progress_callback(event: HFRepoProgressEvent):
+        print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
+        print(f"Estimated time remaining: {event.overall_eta}")
+        print("File Progress:")
+        for file_path, progress in event.file_progress.items():
+            status_icon = {
+                'not_started': '⚪',
+                'in_progress': '🔵',
+                'complete': '✅'
+            }[progress.status]
+            eta_str = str(progress.eta)
+            print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
+                  f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
+        print("\n")
+
+    await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
+    parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
+    parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
+    parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
+    parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
+
+    args = parser.parse_args()
+
+    asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))

+ 1 - 1
main.py

@@ -60,7 +60,7 @@ node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)
-inference_engine.set_on_download_progress(lambda current, total: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": current, "total": total}))))
+inference_engine.set_progress_callback(lambda event: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": event.downloaded_bytes, "total": event.total_bytes}))))
 
 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""

+ 2 - 0
setup.py

@@ -6,6 +6,7 @@ from setuptools import find_packages, setup
 install_requires = [
     "aiohttp==3.9.5",
     "aiohttp_cors==0.7.0",
+    "aiofiles==24.1.0",
     "blobfile==2.1.1",
     "grpcio==1.64.1",
     "grpcio-tools==1.64.1",
@@ -21,6 +22,7 @@ install_requires = [
     "requests==2.32.3",
     "rich==13.7.1",
     "safetensors==0.4.3",
+    "tenacity==9.0.0",
     "tiktoken==0.7.0",
     "tokenizers==0.19.1",
     "tqdm==4.66.4",