Separate hf_helpers into its own module, add an extra/ dir with a download_hf script, and unify downloading so tinygrad uses the same method as mlx, with interoperable model formats

Alex Cheema · 9 months ago · commit 545a486ed3

+ 2 - 19
exo/api/chatgpt_api.py

@@ -25,11 +25,11 @@ shard_mappings = {
   },
   "llama-3-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
   },
   ### mistral
   "mistral-nemo": {
@@ -76,14 +76,6 @@ class ChatCompletionRequest:
         }


-def resolve_tinygrad_tokenizer(model_id: str):
-  if model_id == "llama3-8b-sfr":
-    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
-  elif model_id == "llama3-70b-sfr":
-    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
-  else:
-    raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")
-

async def resolve_tokenizer(model_id: str):
   try:
@@ -111,15 +103,6 @@ async def resolve_tokenizer(model_id: str):

     if DEBUG >= 2: print(traceback.format_exc())

-  try:
-    if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
-    return resolve_tinygrad_tokenizer(model_id)
-  except Exception as e:
-    if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer. Error: {e}")
-    import traceback
-
-    if DEBUG >= 2: print(traceback.format_exc())
-
   if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
   from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer


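With both engines now pointing at Hugging Face repos, shard_mappings keys a model name to a per-engine Shard, so the API layer no longer needs a tinygrad-specific tokenizer path. A minimal sketch of the lookup pattern (the Shard below is a simplified stand-in for exo.inference.shard.Shard, not the real class):

```python
# Sketch of the shard_mappings lookup used in chatgpt_api.py; Shard is a stand-in.
from dataclasses import dataclass

@dataclass
class Shard:
    model_id: str
    start_layer: int
    end_layer: int
    n_layers: int

shard_mappings = {
    "llama-3-8b": {
        "MLXDynamicShardInferenceEngine": Shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", 0, 0, 32),
        "TinygradDynamicShardInferenceEngine": Shard("mlx-community/Meta-Llama-3-8B-Instruct", 0, 0, 32),
    },
}

def shard_for(model_name: str, engine) -> Shard:
    # Both engines resolve through the same table, keyed by engine class name.
    return shard_mappings[model_name][type(engine).__name__]
```
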
+ 95 - 114
hf_async.py → exo/inference/hf_helpers.py

@@ -1,36 +1,17 @@
 import asyncio
 import aiohttp
 import os
-import argparse
 from urllib.parse import urljoin
-from typing import Callable, Optional, Coroutine, Any
+from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from datetime import datetime, timedelta
 from fnmatch import fnmatch
 from pathlib import Path
-from typing import Generator, Iterable, List, TypeVar, Union
+from typing import Generator, Iterable, TypeVar, TypedDict
+from dataclasses import dataclass
+from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
+from exo.helpers import DEBUG

 T = TypeVar("T")
-
-DEFAULT_ALLOW_PATTERNS = [
-    "*.json",
-    "*.py",
-    "tokenizer.model",
-    "*.tiktoken",
-    "*.txt",
-    "*.safetensors",
-]
-# Always ignore `.git` and `.cache/huggingface` folders in commits
-DEFAULT_IGNORE_PATTERNS = [
-    ".git",
-    ".git/*",
-    "*/.git",
-    "**/.git/**",
-    ".cache/huggingface",
-    ".cache/huggingface/*",
-    "*/.cache/huggingface",
-    "**/.cache/huggingface/**",
-]
-
 def filter_repo_objects(
     items: Iterable[T],
     *,
@@ -117,7 +98,35 @@ async def fetch_file_list(session, repo_id, revision, path=""):
         else:
             raise Exception(f"Failed to fetch file list: {response.status}")

-async def download_file(session, repo_id, revision, file_path, save_directory, progress_callback: Optional[Callable[[str, int, int, float, timedelta], Coroutine[Any, Any, None]]] = None):
+
+@dataclass
+class HFRepoFileProgressEvent:
+    file_path: str
+    downloaded: int
+    total: int
+    speed: float
+    eta: timedelta
+    status: Literal["not_started", "in_progress", "complete"]
+
+@dataclass
+class HFRepoProgressEvent:
+    completed_files: int
+    total_files: int
+    downloaded_bytes: int
+    total_bytes: int
+    overall_eta: timedelta
+    file_progress: Dict[str, HFRepoFileProgressEvent]
+
+HFRepoFileProgressCallback = Callable[[HFRepoFileProgressEvent], Coroutine[Any, Any, None]]
+HFRepoProgressCallback = Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]
+
+@retry(
+    stop=stop_after_attempt(5),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
+    reraise=True
+)
+async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[HFRepoFileProgressCallback] = None, use_range_request: bool = True):
     base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
     url = urljoin(base_url, file_path)
     local_path = os.path.join(save_directory, file_path)
@@ -125,64 +134,72 @@ async def download_file(session, repo_id, revision, file_path, save_directory, p
     os.makedirs(os.path.dirname(local_path), exist_ok=True)

     # Check if file already exists and get its size
-    if os.path.exists(local_path):
-        local_file_size = os.path.getsize(local_path)
-    else:
-        local_file_size = 0
+    local_file_size = os.path.getsize(local_path) if os.path.exists(local_path) else 0
+
+    headers = get_auth_headers()
+    if use_range_request:
+        headers["Range"] = f"bytes={local_file_size}-"

-    headers = {"Range": f"bytes={local_file_size}-"}
-    headers.update(get_auth_headers())
     async with session.get(url, headers=headers) as response:
+        total_size = int(response.headers.get('Content-Length', 0))
+        downloaded_size = local_file_size
+        mode = 'ab' if use_range_request else 'wb'
+        if downloaded_size == total_size:
+            if DEBUG >= 2: print(f"File already downloaded: {file_path}")
+            if progress_callback:
+                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
+            return
+
         if response.status == 200:
-            # File doesn't support range requests, start from beginning
+            # File doesn't support range requests or we're not using them, start from beginning
             mode = 'wb'
-            total_size = int(response.headers.get('Content-Length', 0))
             downloaded_size = 0
         elif response.status == 206:
             # Partial content, resume download
-            mode = 'ab'
-            content_range = response.headers.get('Content-Range')
-            total_size = int(content_range.split('/')[-1])
-            downloaded_size = local_file_size
+            content_range = response.headers.get('Content-Range', '')
+            try:
+                total_size = int(content_range.split('/')[-1])
+            except ValueError:
+                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
         elif response.status == 416:
             # Range not satisfiable, get the actual file size
-            if response.headers.get('Content-Type', '').startswith('text/html'):
-                content = await response.text()
-                print(f"Response content (HTML):\n{content}")
-            else:
-                print(response)
-            print("Return header: ", response.headers)
-            print("Return header: ", response.headers.get('Content-Range').split('/')[-1])
-            total_size = int(response.headers.get('Content-Range', '').split('/')[-1])
-            if local_file_size == total_size:
-                print(f"File already fully downloaded: {file_path}")
-                return
-            else:
-                # Start the download from the beginning
-                mode = 'wb'
-                downloaded_size = 0
+            content_range = response.headers.get('Content-Range', '')
+            try:
+                total_size = int(content_range.split('/')[-1])
+                if downloaded_size == total_size:
+                    if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
+                    if progress_callback:
+                        await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
+                    return
+            except ValueError:
+                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
         else:
-            print(f"Failed to download {file_path}: {response.status}")
-            return
+            raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")

         if downloaded_size == total_size:
             print(f"File already downloaded: {file_path}")
+            if progress_callback:
+                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
             return

+        DOWNLOAD_CHUNK_SIZE = 32768
         start_time = datetime.now()
-        new_downloaded_size = 0
         with open(local_path, mode) as f:
-            async for chunk in response.content.iter_chunked(8192):
+            async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
                 f.write(chunk)
-                new_downloaded_size += len(chunk)
-                if progress_callback:
+                downloaded_size += len(chunk)
+                if progress_callback and total_size:
                     elapsed_time = (datetime.now() - start_time).total_seconds()
-                    speed = new_downloaded_size / elapsed_time if elapsed_time > 0 else 0
-                    eta = timedelta(seconds=(total_size - downloaded_size - new_downloaded_size) / speed) if speed > 0 else timedelta(0)
-                    await progress_callback(file_path, new_downloaded_size, total_size - downloaded_size, speed, eta)
-        print(f"Downloaded: {file_path}")
-
-async def download_all_files(repo_id, revision="main", progress_callback: Optional[Callable[[int, int, int, int, timedelta, dict], Coroutine[Any, Any, None]]] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
+                    speed = downloaded_size / elapsed_time if elapsed_time > 0 else 0
+                    remaining_size = total_size - downloaded_size
+                    eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
+                    status = "in_progress" if downloaded_size < total_size else "complete"
+                    await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, speed, eta, status))
+        if DEBUG >= 2: print(f"Downloaded: {file_path}")
+
+async def download_all_files(repo_id: str, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
     repo_root = get_repo_root(repo_id)
     refs_dir = repo_root / "refs"
     snapshots_dir = repo_root / "snapshots"
@@ -197,7 +214,7 @@ async def download_all_files(repo_id, revision="main", progress_callback: Option
         headers = get_auth_headers()
         async with session.get(api_url, headers=headers) as response:
             if response.status != 200:
-                raise Exception(f"Failed to fetch revision info: {response.status}")
+                raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
             revision_info = await response.json()
             commit_hash = revision_info['sha']

@@ -215,68 +232,32 @@ async def download_all_files(repo_id, revision="main", progress_callback: Option
         completed_files = 0
         total_bytes = sum(file["size"] for file in filtered_file_list)
         downloaded_bytes = 0
-        new_downloaded_bytes = 0
-        file_progress = {file["path"]: {"status": "not_started", "downloaded": 0, "total": file["size"]} for file in filtered_file_list}
+        file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
         start_time = datetime.now()

         async def download_with_progress(file_info):
-            nonlocal completed_files, downloaded_bytes, new_downloaded_bytes, file_progress
-
-            async def file_progress_callback(path, file_downloaded, file_total, speed, file_eta):
-                nonlocal downloaded_bytes, new_downloaded_bytes, file_progress
-                new_downloaded_bytes += file_downloaded - file_progress[path]['downloaded']
-                downloaded_bytes += file_downloaded - file_progress[path]['downloaded']
-                file_progress[path].update({
-                    'status': 'in_progress',
-                    'downloaded': file_downloaded,
-                    'total': file_total,
-                    'speed': speed,
-                    'eta': file_eta
-                })
+            nonlocal completed_files, downloaded_bytes, file_progress
+
+            async def file_progress_callback(event: HFRepoFileProgressEvent):
+                nonlocal downloaded_bytes, file_progress
+                downloaded_bytes += event.downloaded - file_progress[event.file_path].downloaded
+                file_progress[event.file_path] = event
                 if progress_callback:
                     elapsed_time = (datetime.now() - start_time).total_seconds()
-                    overall_speed = new_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
+                    overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
                     overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
-                    await progress_callback(completed_files, total_files, new_downloaded_bytes, total_bytes, overall_eta, file_progress)
+                    await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))

             await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
             completed_files += 1
-            file_progress[file_info["path"]]['status'] = 'complete'
+            file_progress[file_info["path"]] = HFRepoFileProgressEvent(file_info["path"], file_info["size"], file_info["size"], 0, timedelta(0), "complete")
             if progress_callback:
                 elapsed_time = (datetime.now() - start_time).total_seconds()
-                overall_speed = new_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
+                overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
                 overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
-                await progress_callback(completed_files, total_files, new_downloaded_bytes, total_bytes, overall_eta, file_progress)
+                await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))

         tasks = [download_with_progress(file_info) for file_info in filtered_file_list]
         await asyncio.gather(*tasks)

-async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
-    async def progress_callback(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress):
-        print(f"Overall Progress: {completed_files}/{total_files} files, {downloaded_bytes}/{total_bytes} bytes")
-        print(f"Estimated time remaining: {overall_eta}")
-        print("File Progress:")
-        for file_path, progress in file_progress.items():
-            status_icon = {
-                'not_started': '⚪',
-                'in_progress': '🔵',
-                'complete': '✅'
-            }[progress['status']]
-            eta_str = str(progress.get('eta', 'N/A'))
-            print(f"{status_icon} {file_path}: {progress.get('downloaded', 0)}/{progress['total']} bytes, "
-                  f"Speed: {progress.get('speed', 0):.2f} B/s, ETA: {eta_str}")
-        print("\n")
-
-    await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
-    parser.add_argument("--repo-id", help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
-    parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
-    parser.add_argument("--allow-patterns", nargs="*", default=DEFAULT_ALLOW_PATTERNS, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
-    parser.add_argument("--ignore-patterns", nargs="*", default=DEFAULT_IGNORE_PATTERNS, help="Patterns of files to ignore (e.g., '.*')")
-
-    args = parser.parse_args()
-
-    asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
+    return snapshot_dir

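download_all_files now returns the snapshot directory and reports progress through typed events, with transient network errors retried by tenacity's exponential backoff on download_file. A minimal usage sketch, assuming only the signatures shown above:

```python
# Minimal usage sketch of the unified downloader, based on the signatures above.
import asyncio
from exo.inference.hf_helpers import download_all_files, HFRepoProgressEvent

async def on_progress(event: HFRepoProgressEvent):
    print(f"{event.completed_files}/{event.total_files} files, "
          f"{event.downloaded_bytes}/{event.total_bytes} bytes, ETA {event.overall_eta}")

async def main():
    snapshot_dir = await download_all_files(
        "mlx-community/Meta-Llama-3-8B-Instruct",  # any HF repo id
        revision="main",
        progress_callback=on_progress,
        allow_patterns=["*.json", "*.safetensors"],
    )
    print(f"snapshot at {snapshot_dir}")

asyncio.run(main())
```
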
+ 3 - 3
exo/inference/inference_engine.py

@@ -1,9 +1,9 @@
 import numpy as np

-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional, Callable, Coroutine, Any
 from abc import ABC, abstractmethod
 from .shard import Shard
-
+from exo.inference.hf_helpers import HFRepoProgressEvent

 class InferenceEngine(ABC):
   @abstractmethod
@@ -15,5 +15,5 @@ class InferenceEngine(ABC):
     pass

   @abstractmethod
-  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
+  def set_progress_callback(self, progress_callback: Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]):
     pass

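Both engines now share one async progress-callback contract instead of the old (current, total) integer pair. A sketch of what a conforming engine stores; the class below is illustrative only, with the other abstract methods omitted:

```python
# Illustrative sketch of the new contract; infer_prompt/infer_tensor are omitted,
# so this is not a complete InferenceEngine implementation.
from typing import Any, Callable, Coroutine, Optional
from exo.inference.hf_helpers import HFRepoProgressEvent

ProgressCallback = Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]

class ProgressReportingEngine:
    def __init__(self, progress_callback: Optional[ProgressCallback] = None):
        self.progress_callback = progress_callback

    def set_progress_callback(self, progress_callback: ProgressCallback):
        # Stored here and later handed to download_all_files by ensure_shard.
        self.progress_callback = progress_callback
```
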
+ 6 - 5
exo/inference/mlx/sharded_inference_engine.py

@@ -5,12 +5,13 @@ from .sharded_model import StatefulShardedModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
 from typing import Optional, Callable
+from exo.inference.hf_helpers import HFRepoProgressCallback


 class MLXDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self, on_download_progress: Callable[[int, int], None] = None):
+  def __init__(self, progress_callback: Optional[HFRepoProgressCallback] = None):
     self.shard = None
-    self.on_download_progress = on_download_progress
+    self.progress_callback = progress_callback

   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
@@ -33,9 +34,9 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return

-    model_shard, self.tokenizer = await load_shard(shard.model_id, shard, on_download_progress=self.on_download_progress)
+    model_shard, self.tokenizer = await load_shard(shard.model_id, shard, progress_callback=self.progress_callback)
     self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
     self.shard = shard

-  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
-    self.on_download_progress = on_download_progress
+  def set_progress_callback(self, progress_callback: HFRepoProgressCallback):
+    self.progress_callback = progress_callback

+ 7 - 54
exo/inference/mlx/sharded_utils.py

@@ -12,10 +12,7 @@ from typing import Optional, Tuple, Union, List, Callable
 from PIL import Image
 from io import BytesIO
 import base64
-import os
-import concurrent.futures

-from exo import DEBUG
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download, list_repo_tree, get_paths_info
@@ -28,6 +25,8 @@ from transformers import AutoProcessor
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tuner.utils import apply_lora_layers

+from exo import DEBUG
+from exo.inference.hf_helpers import download_all_files, HFRepoProgressCallback
 from ..shard import Shard


@@ -164,52 +163,6 @@ def load_model_shard(
   return model


-async def get_repo_size(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
-  it = await asyncio.to_thread(list_repo_tree, repo_id, revision=revision, repo_type=repo_type)
-  files = list(filter_repo_objects(it, allow_patterns=allow_patterns, key=lambda f: f.path))
-  return sum(file.size for file in files if hasattr(file, "size") and file.size is not None)
-
-async def monitor_progress(dir, total_size, print_progress=False, on_progress: Callable[[int, int], None] = None):
-    while True:
-      try:
-        await asyncio.sleep(0.1)
-        current_size = sum(os.path.getsize(os.path.join(root, file))
-                            for root, _, files in os.walk(dir)
-                            for file in files)
-        progress = min(current_size / total_size * 100, 100)
-        if print_progress:
-          print(f"\rProgress: {progress:.2f}% ({current_size}/{total_size} bytes)", end="", flush=True)
-        if on_progress:
-          on_progress(current_size, total_size)
-        if progress >= 100:
-          if print_progress:
-            print("\nDownload complete!")
-          break
-      except Exception as e:
-        print(f"Error monitoring progress: {e}")
-
-async def download_repo(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
-    with concurrent.futures.ThreadPoolExecutor() as pool:
-        return await asyncio.get_event_loop().run_in_executor(
-            pool,
-            partial(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
-        )
-
-async def download_async_with_progress(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, on_progress: Callable[[int, int], None] = None):
-  storage_folder = os.path.join(HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type="model"))
-  # os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
-  # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
-
-  total_size = await get_repo_size(repo_id)
-
-  # Create tasks for download and progress checking
-  download_task = asyncio.create_task(download_repo(repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type))
-  progress_task = asyncio.create_task(monitor_progress(storage_folder, total_size, on_progress=on_progress))
-
-  # Wait for both tasks to complete
-  result = await asyncio.gather(download_task, progress_task, return_exceptions=True)
-  return result[0]  # Return the result from download_task
-
 repo_id_safetensors_layers = {
   "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit": {
     "model.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
@@ -313,7 +266,7 @@ def get_safetensors_allow_patterns(repo_id: str, shard: Optional[Shard] = None):

     return allow_patterns if len(allow_patterns) > 0 else ["*.safetensors"]

-async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
+async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None) -> Path:
   """
   Ensures the model is available locally. If the path does not exist locally,
   it is downloaded from the Hugging Face Hub.
@@ -329,7 +282,7 @@ async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, re
   if not model_path.exists():
     try:
       model_path = Path(
-        await download_async_with_progress(
+        await download_all_files(
           repo_id=path_or_hf_repo,
           revision=revision,
           allow_patterns=[
@@ -339,7 +292,7 @@ async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, re
             "*.tiktoken",
             "*.txt",
           ] + get_safetensors_allow_patterns(path_or_hf_repo, shard),
-          on_progress=on_download_progress,
+          progress_callback=progress_callback,
         )
       )
     except RepositoryNotFoundError:
@@ -360,7 +313,7 @@ async def load_shard(
   model_config={},
   adapter_path: Optional[str] = None,
   lazy: bool = False,
-  on_download_progress: Callable[[int, int], None] = None,
+  progress_callback: Optional[HFRepoProgressCallback] = None,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
   """
   Load the model and tokenizer from a given path or a huggingface repository.
@@ -383,7 +336,7 @@ async def load_shard(
    FileNotFoundError: If config file or safetensors are not found.
    ValueError: If model class or args class are not found.
   """
-  model_path = await get_model_path(path_or_hf_repo, shard, on_download_progress=on_download_progress)
+  model_path = await get_model_path(path_or_hf_repo, shard, progress_callback=progress_callback)

   model = load_model_shard(model_path, shard, lazy, model_config)
   if adapter_path is not None:

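With the bespoke snapshot and directory-polling helpers removed, the MLX loader delegates entirely to the shared downloader. A hedged usage sketch of the updated entry point, based only on the signatures in this diff:

```python
# Usage sketch for the updated MLX loader; signatures as shown in this diff.
import asyncio
from exo.inference.shard import Shard
from exo.inference.mlx.sharded_utils import load_shard

shard = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit",
              start_layer=0, end_layer=0, n_layers=32)
# progress_callback is optional; omitting it gives a silent download.
model, tokenizer = asyncio.run(load_shard(shard.model_id, shard))
```
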
+ 13 - 12
exo/inference/test_inference_engine.py

@@ -1,3 +1,4 @@
+from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
@@ -40,17 +41,17 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(next_resp_full, resp4)


-asyncio.run(
-  test_inference_engine(
-    MLXDynamicShardInferenceEngine(),
-    MLXDynamicShardInferenceEngine(),
-    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-  )
-)
+# asyncio.run(
+#   test_inference_engine(
+#     MLXDynamicShardInferenceEngine(),
+#     MLXDynamicShardInferenceEngine(),
+#     "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+#   )
+# )

 # TODO: Need more memory or a smaller model
-# asyncio.run(test_inference_engine(
-#     TinygradDynamicShardInferenceEngine(),
-#     TinygradDynamicShardInferenceEngine(),
-#     "llama3-8b-sfr",
-# ))
+asyncio.run(test_inference_engine(
+    TinygradDynamicShardInferenceEngine(),
+    TinygradDynamicShardInferenceEngine(),
+    "llama3-8b-sfr",
+))

+ 9 - 72
exo/inference/tinygrad/inference.py

@@ -1,9 +1,7 @@
-import asyncio
 from functools import partial
 from pathlib import Path
-from typing import List, Optional, Union, Callable
+from typing import List, Optional, Union, Callable, Coroutine, Any
 import json
-import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
@@ -12,7 +10,7 @@ from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
 import numpy as np
-import os
+from exo.inference.hf_helpers import HFRepoProgressCallback, HFRepoProgressEvent, download_all_files, get_repo_root

 MODEL_PARAMS = {
   "8B": {
@@ -46,15 +44,6 @@ MODEL_PARAMS = {


 # **** helper functions ****
-async def fetch_async(
-  url: str,
-  name: Optional[Union[Path, str]] = None,
-  subdir: Optional[str] = None,
-  allow_caching=not os.getenv("DISABLE_HTTP_CACHE"),
-) -> Path:
-  func = partial(fetch, url, name, subdir, allow_caching)
-  return await asyncio.get_event_loop().run_in_executor(None, func)
-

 def concat_weights(models, device=None):
   def convert(name) -> Tensor:
@@ -159,8 +148,9 @@ def prefill(model, toks, start_pos=0):


 class TinygradDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self):
+  def __init__(self, progress_callback: Optional[HFRepoProgressCallback] = None):
     self.shard = None
+    self.progress_callback = progress_callback

   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
@@ -199,62 +189,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return

-    model_path = Path(shard.model_id)
-    models_dir = Path(_cache_dir) / "tinygrad" / "downloads"
-    model_path = models_dir / shard.model_id
-    size = "8B"
-    if Path(model_path / "tokenizer_config.json").exists():
-      model = model_path
-    else:
-
-      if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
-      if shard.model_id.lower().find("llama3-8b-sfr") != -1:
-        num_files = 4
-        for i in range(num_files):
-          await fetch_async(
-            f"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/model-{(i+1):05d}-of-{num_files:05d}.safetensors",
-            f"model-{(i+1):05d}-of-{num_files:05d}.safetensors",
-            subdir=shard.model_id,
-          )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/config.json",
-          "config.json",
-          subdir=shard.model_id,
-        )
-        model = await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/raw/main/model.safetensors.index.json",
-          "model.safetensors.index.json",
-          subdir=shard.model_id,
-        )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/special_tokens_map.json",
-          "special_tokens_map.json",
-          subdir=shard.model_id,
-        )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json",
-          "tokenizer.json",
-          subdir=shard.model_id,
-        )
-        await fetch_async(
-          "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer_config.json",
-          "tokenizer_config.json",
-          subdir=shard.model_id,
-        )
-        size = "8B"
-      elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
-        raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
-        # fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-70B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
-        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
-        # model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
-        # size = "70B"
-      else:
-        raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}")
-
-    model = build_transformer(model_path, shard=shard, model_size=size)
+    model_path = await download_all_files(shard.model_id, progress_callback=self.progress_callback)
+    print(f"{model_path=}")
+    model = build_transformer(model_path, shard=shard, model_size="8B" if "8b" in shard.model_id else "70B" if "70b" in shard.model_id else "8B")
     from transformers import AutoTokenizer
     tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))

@@ -262,5 +199,5 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.model = model
     self.tokenizer = tokenizer

-  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
-    pass
+  def set_progress_callback(self, progress_callback: Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]):
+    self.progress_callback = progress_callback

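Since tinygrad now downloads through hf_helpers as well, wiring up a progress callback is identical to the MLX case. A short sketch, assuming only the constructor shown above:

```python
# Sketch: the tinygrad engine now accepts the same async progress callback.
from exo.inference.hf_helpers import HFRepoProgressEvent
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine

async def report(event: HFRepoProgressEvent) -> None:
    print(f"downloaded {event.downloaded_bytes}/{event.total_bytes} bytes")

engine = TinygradDynamicShardInferenceEngine(progress_callback=report)
```
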
+ 53 - 0
extra/download_hf.py

@@ -0,0 +1,53 @@
+import argparse
+import asyncio
+from exo.inference.hf_helpers import download_all_files, HFRepoProgressEvent, HFRepoFileProgressEvent
+
+DEFAULT_ALLOW_PATTERNS = [
+    "*.json",
+    "*.py",
+    "tokenizer.model",
+    "*.tiktoken",
+    "*.txt",
+    "*.safetensors",
+]
+# Always ignore `.git` and `.cache/huggingface` folders in commits
+DEFAULT_IGNORE_PATTERNS = [
+    ".git",
+    ".git/*",
+    "*/.git",
+    "**/.git/**",
+    ".cache/huggingface",
+    ".cache/huggingface/*",
+    "*/.cache/huggingface",
+    "**/.cache/huggingface/**",
+]
+
+async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
+    async def progress_callback(event: HFRepoProgressEvent):
+        print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
+        print(f"Estimated time remaining: {event.overall_eta}")
+        print("File Progress:")
+        for file_path, progress in event.file_progress.items():
+            status_icon = {
+                'not_started': '⚪',
+                'in_progress': '🔵',
+                'complete': '✅'
+            }[progress.status]
+            eta_str = str(progress.eta)
+            print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
+                  f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
+        print("\n")
+
+    await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
+    parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
+    parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
+    parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
+    parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
+
+    args = parser.parse_args()
+
+    asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))

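The standalone downloader can be exercised directly, e.g. `python extra/download_hf.py --repo-id mlx-community/Meta-Llama-3-8B-Instruct --revision main` (repo id illustrative). Note the pattern flags now default to None rather than the old in-module DEFAULT_ALLOW_PATTERNS/DEFAULT_IGNORE_PATTERNS lists, which moved into this script.
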
+ 1 - 1
main.py

@@ -60,7 +60,7 @@ node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)
-inference_engine.set_on_download_progress(lambda current, total: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": current, "total": total}))))
+inference_engine.set_progress_callback(lambda event: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": event.downloaded_bytes, "total": event.total_bytes}))))

 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""

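The broadcast now carries repo-level byte totals from the typed event. A hedged sketch of the payload this lambda builds; the node id and byte counts are illustrative values, not real output:

```python
# Illustrative reconstruction of the opaque status payload from main.py.
import json

payload = json.dumps({
    "type": "download_progress",
    "node_id": "node-abc123",  # node.id at runtime
    "current": 1073741824,     # event.downloaded_bytes
    "total": 4294967296,       # event.total_bytes
})
print(payload)
```
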
+ 2 - 0
setup.py

@@ -6,6 +6,7 @@ from setuptools import find_packages, setup
 install_requires = [
     "aiohttp==3.9.5",
     "aiohttp_cors==0.7.0",
+    "aiofiles==24.1.0",
     "blobfile==2.1.1",
     "grpcio==1.64.1",
     "grpcio-tools==1.64.1",
@@ -21,6 +22,7 @@ install_requires = [
     "requests==2.32.3",
     "rich==13.7.1",
     "safetensors==0.4.3",
+    "tenacity==9.0.0",
     "tiktoken==0.7.0",
     "tokenizers==0.19.1",
     "tqdm==4.66.4",