
Merge pull request #640 from exo-explore/simpledownload

Simple download
Alex Cheema · 7 months ago
parent commit a0d673fa3a

+ 2 - 2
.circleci/config.yml

@@ -27,7 +27,7 @@ commands:
            fi

            # Start first instance
-            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
+            EXO_HOME="$(pwd)/.exo_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
               --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 \
               --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 \
               --chatgpt-api-response-timeout 900 --disable-tui > output1.log &
               --chatgpt-api-response-timeout 900 --disable-tui > output1.log &
             PID1=$!
             PID1=$!
@@ -35,7 +35,7 @@ commands:
            TAIL1=$!

            # Start second instance
-            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
+            EXO_HOME="$(pwd)/.exo_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
               --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 \
               --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 \
               --chatgpt-api-response-timeout 900 --disable-tui > output2.log &
               --chatgpt-api-response-timeout 900 --disable-tui > output2.log &
             PID2=$!
             PID2=$!

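The CI now keys each node's model cache off `EXO_HOME` instead of `HF_HOME`, so the two test nodes never share downloads. A minimal sketch of the same per-node isolation from Python, assuming `exo` is installed on PATH; the CLI flags are the ones used in the config above:

```python
import os
import subprocess

def start_node(node_id: str, listen: int, broadcast: int, api_port: int) -> subprocess.Popen:
  env = os.environ.copy()
  # Point each node at its own cache, as the CircleCI config does.
  env["EXO_HOME"] = os.path.join(os.getcwd(), f".exo_cache_{node_id}")
  return subprocess.Popen(
    ["exo", "--node-id", node_id,
     "--listen-port", str(listen), "--broadcast-port", str(broadcast),
     "--chatgpt-api-port", str(api_port), "--disable-tui"],
    env=env,
  )

p1 = start_node("node1", 5678, 5679, 8000)
p2 = start_node("node2", 5679, 5678, 8001)
```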
+ 1 - 0
.github/workflows/bench_job.yml

@@ -128,6 +128,7 @@ jobs:
            --interface-type-filter="${{ inputs.network_interface }}" \
            --disable-tui \
            --max-generate-tokens 250 \
+            --chatgpt-api-response-timeout 900 \
            --chatgpt-api-port 52415 > output1.log 2>&1 &
          PID1=$!

+ 2 - 2
README.md

@@ -212,9 +212,9 @@ exo run llama-3.2-3b --prompt "What is the meaning of exo?"

### Model Storage

-Models by default are stored in `~/.cache/huggingface/hub`.
+Models by default are stored in `~/.cache/exo/downloads`.
-You can set a different model storage location by setting the `HF_HOME` env var.
+You can set a different model storage location by setting the `EXO_HOME` env var.

## Debugging


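For clarity, the resolution order matches the `exo_home()` helper introduced in `exo/download/new_shard_download.py` below: `EXO_HOME` wins when set, otherwise `~/.cache/exo`, with models under `downloads/`:

```python
import os
from pathlib import Path

def exo_home() -> Path:
  # Same logic as the helper added in this PR.
  return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))

print(exo_home()/"downloads")  # default: ~/.cache/exo/downloads
```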
+ 21 - 79
exo/api/chatgpt_api.py

@@ -11,30 +11,27 @@ import aiohttp_cors
 import traceback
 import signal
 from exo import DEBUG, VERSION
-from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo, pretty_name
+from exo.models import build_base_shard, build_full_shard, model_cards, get_repo, get_supported_models, get_pretty_name
 from typing import Callable, Optional
 from PIL import Image
 import numpy as np
 import base64
 from io import BytesIO
 import platform
+from exo.download.download_progress import RepoProgressEvent
+from exo.download.new_shard_download import delete_model
+import tempfile
+from exo.apputil import create_animation_mp4
+from collections import defaultdict
 
 if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
   import mlx.core as mx
 else:
   import numpy as mx
 
-import tempfile
-from exo.download.hf.hf_shard_download import HFShardDownloader
-import shutil
-from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
-from exo.apputil import create_animation_mp4
-from collections import defaultdict
-
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
@@ -277,41 +274,12 @@ class ChatGPTAPI:
 
   async def handle_model_support(self, request):
     try:
-      response = web.StreamResponse(status=200, reason='OK', headers={
-        'Content-Type': 'text/event-stream',
-        'Cache-Control': 'no-cache',
-        'Connection': 'keep-alive',
-      })
+      response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
       await response.prepare(request)
-
-      async def process_model(model_name, pretty):
-        if model_name in model_cards:
-          model_info = model_cards[model_name]
-
-          if self.inference_engine_classname in model_info.get("repo", {}):
-            shard = build_base_shard(model_name, self.inference_engine_classname)
-            if shard:
-              downloader = HFShardDownloader(quick_check=True)
-              downloader.current_shard = shard
-              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-              status = await downloader.get_shard_download_status()
-
-              download_percentage = status.get("overall") if status else None
-              total_size = status.get("total_size") if status else None
-              total_downloaded = status.get("total_downloaded") if status else False
-
-              model_data = {
-                model_name: {
-                  "name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
-                  "total_downloaded": total_downloaded
-                }
-              }
-
-              await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
-
-      # Process all models in parallel
-      await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])
-
+      downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
+      for (path, d) in downloads:
+        model_data = { d.shard.model_id: { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
+        await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
       await response.write(b"data: [DONE]\n\n")
       await response.write(b"data: [DONE]\n\n")
       return response
       return response
 
 
@@ -348,6 +316,7 @@ class ChatGPTAPI:
     progress_data = {}
     for node_id, progress_event in self.node.node_download_progress.items():
       if isinstance(progress_event, RepoProgressEvent):
+        if progress_event.status != "in_progress": continue
         progress_data[node_id] = progress_event.to_dict()
       else:
         print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
@@ -574,46 +543,19 @@ class ChatGPTAPI:
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
 
 
   async def handle_delete_model(self, request):
   async def handle_delete_model(self, request):
+    model_id = request.match_info.get('model_name')
     try:
-      model_name = request.match_info.get('model_name')
-      if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
-
-      if not model_name or model_name not in model_cards:
-        return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)
-
-      shard = build_base_shard(model_name, self.inference_engine_classname)
-      if not shard:
-        return web.json_response({"detail": "Could not build shard for model"}, status=400)
-
-      repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-      if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
-
-      # Get the HF cache directory using the helper function
-      hf_home = get_hf_home()
-      cache_dir = get_repo_root(repo_id)
-
-      if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
-
-      if os.path.exists(cache_dir):
-        if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
-        try:
-          shutil.rmtree(cache_dir)
-          return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
-        except Exception as e:
-          return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
-      else:
-        return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)
-
+      if await delete_model(model_id, self.inference_engine_classname): return web.json_response({"status": "success", "message": f"Model {model_id} deleted successfully"})
+      else: return web.json_response({"detail": f"Model {model_id} files not found"}, status=404)
     except Exception as e:
     except Exception as e:
-      print(f"Error in handle_delete_model: {str(e)}")
-      traceback.print_exc()
-      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500)
 
   async def handle_get_initial_models(self, request):
     model_data = {}
-    for model_name, pretty in pretty_name.items():
-      model_data[model_name] = {
-        "name": pretty,
+    for model_id in get_supported_models([[self.inference_engine_classname]]):
+      model_data[model_id] = {
+        "name": get_pretty_name(model_id),
         "downloaded": None,  # Initially unknown
         "downloaded": None,  # Initially unknown
         "download_percentage": None,  # Change from 0 to null
         "download_percentage": None,  # Change from 0 to null
         "total_size": None,
         "total_size": None,
@@ -659,7 +601,7 @@ class ChatGPTAPI:
       model_name = data.get("model")
       model_name = data.get("model")
       if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
       if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
       if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
       if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
-      shard = build_base_shard(model_name, self.inference_engine_classname)
+      shard = build_full_shard(model_name, self.inference_engine_classname)
       if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
       asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
 

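For reference, the new `handle_model_support` streams one SSE message per model and closes with a `[DONE]` sentinel. A minimal consumer sketch, assuming the handler is mounted at `/modelpool` (the route registration is outside this diff) and the default API port 52415 seen in the bench config:

```python
import asyncio, json
import aiohttp

async def watch_model_support(base_url: str = "http://localhost:52415"):
  async with aiohttp.ClientSession() as session:
    async with session.get(f"{base_url}/modelpool") as resp:
      async for raw in resp.content:  # StreamReader yields one line at a time
        line = raw.decode().strip()
        if not line.startswith("data: "): continue
        payload = line[len("data: "):]
        if payload == "[DONE]": break
        # {model_id: {downloaded, download_percentage, total_size, total_downloaded}}
        print(json.loads(payload))

asyncio.run(watch_model_support())
```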
+ 6 - 2
exo/download/download_progress.py

@@ -1,4 +1,5 @@
 from typing import Dict, Callable, Coroutine, Any, Literal
+from exo.inference.shard import Shard
 from dataclasses import dataclass
 from datetime import timedelta
 
@@ -14,11 +15,12 @@ class RepoFileProgressEvent:
   speed: int
   eta: timedelta
   status: Literal["not_started", "in_progress", "complete"]
+  start_time: float
 
   def to_dict(self):
     return {
       "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
-      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
+      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status, "start_time": self.start_time
     }
 
   @classmethod
@@ -29,6 +31,7 @@ class RepoFileProgressEvent:
 
 @dataclass
 class RepoProgressEvent:
+  shard: Shard
   repo_id: str
   repo_revision: str
   completed_files: int
@@ -43,7 +46,7 @@ class RepoProgressEvent:
 
   def to_dict(self):
     return {
-      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
+      "shard": self.shard.to_dict(), "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
       "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
       "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
       "file_progress": {k: v.to_dict()
       "file_progress": {k: v.to_dict()
                         for k, v in self.file_progress.items()}, "status": self.status
                         for k, v in self.file_progress.items()}, "status": self.status
@@ -53,6 +56,7 @@ class RepoProgressEvent:
   def from_dict(cls, data):
     if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta'])
     if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
+    if 'shard' in data: data['shard'] = Shard.from_dict(data['shard'])
 
     return cls(**data)
 

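Since `RepoProgressEvent` now carries the shard, `to_dict()`/`from_dict()` must round-trip it through `Shard.to_dict()`/`Shard.from_dict()`. A quick sanity sketch, assuming `Shard`'s dataclass fields are `(model_id, start_layer, end_layer, n_layers)` and the repo name is illustrative:

```python
from datetime import timedelta
from exo.inference.shard import Shard
from exo.download.download_progress import RepoProgressEvent

event = RepoProgressEvent(
  shard=Shard(model_id="llama-3.2-3b", start_layer=0, end_layer=31, n_layers=32),
  repo_id="mlx-community/Llama-3.2-3B-Instruct-4bit", repo_revision="main",
  completed_files=0, total_files=1, downloaded_bytes=0,
  downloaded_bytes_this_session=0, total_bytes=1, overall_speed=0,
  overall_eta=timedelta(seconds=0), file_progress={}, status="not_started",
)
# The shard survives serialization; dataclass equality checks every field.
assert RepoProgressEvent.from_dict(event.to_dict()) == event
```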
+ 3 - 412
exo/download/hf/hf_helpers.py

@@ -1,36 +1,16 @@
 import aiofiles.os as aios
 from typing import Union
-import asyncio
-import aiohttp
-import json
 import os
-import sys
-import shutil
-from urllib.parse import urljoin
-from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
-from datetime import datetime, timedelta
+from typing import Callable, Optional, Dict, List, Union
 from fnmatch import fnmatch
 from pathlib import Path
-from typing import Generator, Iterable, TypeVar, TypedDict
-from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG, is_frozen
-from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
+from typing import Generator, Iterable, TypeVar
+from exo.helpers import DEBUG
 from exo.inference.shard import Shard
 import aiofiles
 
 T = TypeVar("T")
 
-async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
-  refs_dir = get_repo_root(repo_id)/"refs"
-  refs_file = refs_dir/revision
-  if await aios.path.exists(refs_file):
-    async with aiofiles.open(refs_file, 'r') as f:
-      commit_hash = (await f.read()).strip()
-      snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
-      return snapshot_dir
-  return None
-
-
 def filter_repo_objects(
   items: Iterable[T],
   *,
@@ -48,14 +28,12 @@ def filter_repo_objects(
     ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
 
   if key is None:
-
     def _identity(item: T) -> str:
       if isinstance(item, str):
         return item
       if isinstance(item, Path):
         return str(item)
       raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
-
     key = _identity
 
   for item in items:
@@ -66,22 +44,18 @@ def filter_repo_objects(
       continue
     yield item
 
-
 def _add_wildcard_to_directories(pattern: str) -> str:
   if pattern[-1] == "/":
     return pattern + "*"
   return pattern
 
-
 def get_hf_endpoint() -> str:
   return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
 
-
 def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
   return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))
 
-
 async def get_hf_token():
   """Retrieve the Hugging Face token from the user's HF_HOME directory."""
   token_path = get_hf_home()/"token"
@@ -90,7 +64,6 @@ async def get_hf_token():
       return (await f.read()).strip()
   return None
 
-
 async def get_auth_headers():
   """Get authentication headers if a token is available."""
   token = await get_hf_token()
@@ -98,325 +71,6 @@ async def get_auth_headers():
     return {"Authorization": f"Bearer {token}"}
     return {"Authorization": f"Bearer {token}"}
   return {}
   return {}
 
 
-
-def get_repo_root(repo_id: str) -> Path:
-  """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = str(repo_id).replace("/", "--")
-  return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
-
-async def move_models_to_hf(seed_dir: Union[str, Path]):
-  """Move model in resources folder of app to .cache/huggingface/hub"""
-  source_dir = Path(seed_dir)
-  dest_dir = get_hf_home()/"hub"
-  await aios.makedirs(dest_dir, exist_ok=True)  
-  for path in source_dir.iterdir():
-    if path.is_dir() and path.name.startswith("models--"):
-      dest_path = dest_dir / path.name
-      if await aios.path.exists(dest_path):
-        print('Skipping moving model to .cache directory')
-      else:
-        try:
-          await aios.rename(str(path), str(dest_path))
-        except Exception as e:
-          print(f'Error moving model to .cache: {e}')
-    
-    
-    
-async def fetch_file_list(session, repo_id, revision, path=""):
-  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
-  url = f"{api_url}/{path}" if path else api_url
-
-  headers = await get_auth_headers()
-  async with session.get(url, headers=headers) as response:
-    if response.status == 200:
-      data = await response.json()
-      files = []
-      for item in data:
-        if item["type"] == "file":
-          files.append({"path": item["path"], "size": item["size"]})
-        elif item["type"] == "directory":
-          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
-          files.extend(subfiles)
-      return files
-    else:
-      raise Exception(f"Failed to fetch file list: {response.status}")
-
-
-@retry(
-  stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
-)
-async def download_file(
-  session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
-):
-  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-  url = urljoin(base_url, file_path)
-  local_path = os.path.join(save_directory, file_path)
-
-  await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
-
-  # Check if file already exists and get its size
-  local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
-
-  headers = await get_auth_headers()
-  if use_range_request:
-    headers["Range"] = f"bytes={local_file_size}-"
-
-  async with session.get(url, headers=headers) as response:
-    total_size = int(response.headers.get('Content-Length', 0))
-    downloaded_size = local_file_size
-    downloaded_this_session = 0
-    mode = 'ab' if use_range_request else 'wb'
-    percentage = await get_file_download_percentage(
-      session,
-      repo_id,
-      revision,
-      file_path,
-      Path(save_directory)
-    )
-    
-    if percentage == 100:
-      if DEBUG >= 2: print(f"File already downloaded: {file_path}")
-      if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
-      return
-
-    if response.status == 200:
-      # File doesn't support range requests or we're not using them, start from beginning
-      mode = 'wb'
-      downloaded_size = 0
-    elif response.status == 206:
-      # Partial content, resume download
-      content_range = response.headers.get('Content-Range', '')
-      try:
-        total_size = int(content_range.split('/')[-1])
-      except ValueError:
-        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-    elif response.status == 416:
-      # Range not satisfiable, get the actual file size
-      content_range = response.headers.get('Content-Range', '')
-      try:
-        total_size = int(content_range.split('/')[-1])
-        if downloaded_size == total_size:
-          if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
-          if progress_callback:
-            await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-          return
-      except ValueError:
-        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-    else:
-      raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
-
-    if downloaded_size == total_size:
-      print(f"File already downloaded: {file_path}")
-      if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-      return
-
-    DOWNLOAD_CHUNK_SIZE = 32768
-    start_time = datetime.now()
-    async with aiofiles.open(local_path, mode) as f:
-      async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
-        await f.write(chunk)
-        downloaded_size += len(chunk)
-        downloaded_this_session += len(chunk)
-        if progress_callback and total_size:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
-          remaining_size = total_size - downloaded_size
-          eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
-          status = "in_progress" if downloaded_size < total_size else "complete"
-          if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
-          await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
-    if DEBUG >= 2: print(f"Downloaded: {file_path}")
-
-
-async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
-  repo_root = get_repo_root(repo_id)
-  refs_dir = repo_root/"refs"
-  refs_file = refs_dir/revision
-
-  # Check if we have a cached commit hash
-  if await aios.path.exists(refs_file):
-    async with aiofiles.open(refs_file, 'r') as f:
-      commit_hash = (await f.read()).strip()
-      if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
-      return commit_hash
-
-  # Fetch the commit hash for the given revision
-  async with aiohttp.ClientSession() as session:
-    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
-    headers = await get_auth_headers()
-    async with session.get(api_url, headers=headers) as response:
-      if response.status != 200:
-        raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
-      revision_info = await response.json()
-      commit_hash = revision_info['sha']
-
-  # Cache the commit hash
-  await aios.makedirs(refs_dir, exist_ok=True)
-  async with aiofiles.open(refs_file, 'w') as f:
-    await f.write(commit_hash)
-
-  return commit_hash
-
-
-async def download_repo_files(
-  repo_id: str,
-  revision: str = "main",
-  progress_callback: Optional[RepoProgressCallback] = None,
-  allow_patterns: Optional[Union[List[str], str]] = None,
-  ignore_patterns: Optional[Union[List[str], str]] = None,
-  max_parallel_downloads: int = 4
-) -> Path:
-  repo_root = get_repo_root(repo_id)
-  snapshots_dir = repo_root/"snapshots"
-  cachedreqs_dir = repo_root/"cachedreqs"
-
-  # Ensure directories exist
-  await aios.makedirs(snapshots_dir, exist_ok=True)
-  await aios.makedirs(cachedreqs_dir, exist_ok=True)
-
-  # Resolve revision to commit hash
-  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
-
-  # Set up the snapshot directory
-  snapshot_dir = snapshots_dir/commit_hash
-  await aios.makedirs(snapshot_dir, exist_ok=True)
-
-  # Set up the cached file list directory
-  cached_file_list_dir = cachedreqs_dir/commit_hash
-  await aios.makedirs(cached_file_list_dir, exist_ok=True)
-  cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"
-
-  async with aiohttp.ClientSession() as session:
-    # Check if we have a cached file list
-    if await aios.path.exists(cached_file_list_path):
-      async with aiofiles.open(cached_file_list_path, 'r') as f:
-        file_list = json.loads(await f.read())
-      if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
-    else:
-      file_list = await fetch_file_list(session, repo_id, revision)
-      # Cache the file list
-      async with aiofiles.open(cached_file_list_path, 'w') as f:
-        await f.write(json.dumps(file_list))
-      if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
-
-    model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
-    if model_index_exists:
-      allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
-
-    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
-    total_files = len(filtered_file_list)
-    total_bytes = sum(file["size"] for file in filtered_file_list)
-    file_progress: Dict[str, RepoFileProgressEvent] = {
-      file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
-      for file in filtered_file_list
-    }
-    start_time = datetime.now()
-
-    async def download_with_progress(file_info, progress_state):
-      local_path = snapshot_dir/file_info["path"]
-      if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
-        if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
-        progress_state['completed_files'] += 1
-        progress_state['downloaded_bytes'] += file_info["size"]
-        file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
-        if progress_callback:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-          status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-          await progress_callback(
-            RepoProgressEvent(
-              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-              overall_eta, file_progress, status
-            )
-          )
-        return
-
-      async def file_progress_callback(event: RepoFileProgressEvent):
-        progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
-        progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
-        file_progress[event.file_path] = event
-        if progress_callback:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-          status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
-          await progress_callback(
-            RepoProgressEvent(
-              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-              overall_eta, file_progress, status
-            )
-          )
-
-      await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
-      progress_state['completed_files'] += 1
-      file_progress[
-        file_info["path"]
-      ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
-      if progress_callback:
-        elapsed_time = (datetime.now() - start_time).total_seconds()
-        overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-        remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-        overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-        status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-        await progress_callback(
-          RepoProgressEvent(
-            repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-            overall_eta, file_progress, status
-          )
-        )
-
-    progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
-
-    semaphore = asyncio.Semaphore(max_parallel_downloads)
-
-    async def download_with_semaphore(file_info):
-      async with semaphore:
-        await download_with_progress(file_info, progress_state)
-
-    tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
-    await asyncio.gather(*tasks)
-
-  return snapshot_dir
-
-
-async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
-  """
-    Retrieve the weight map from the model.safetensors.index.json file.
-
-    Args:
-        repo_id (str): The Hugging Face repository ID.
-        revision (str): The revision of the repository to use.
-
-    Returns:
-        Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
-    """
-
-  # Download the index file
-  await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
-
-  # Check if the file exists
-  repo_root = get_repo_root(repo_id)
-  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
-  snapshot_dir = repo_root/"snapshots"/commit_hash
-  index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
-
-  if index_file:
-    index_file_path = snapshot_dir/index_file
-    if await aios.path.exists(index_file_path):
-      async with aiofiles.open(index_file_path, 'r') as f:
-        index_data = json.loads(await f.read())
-      return index_data.get("weight_map")
-
-  return None
-
-
 def extract_layer_num(tensor_name: str) -> Optional[int]:
   # This is a simple example and might need to be adjusted based on the actual naming convention
   parts = tensor_name.split('.')
@@ -425,7 +79,6 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
       return int(part)
   return None
 
-
 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
   default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
   shard_specific_patterns = set()
@@ -443,65 +96,3 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
     shard_specific_patterns = set(["*.safetensors"])
   if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
-
-async def get_file_download_percentage(
-    session: aiohttp.ClientSession,
-    repo_id: str,
-    revision: str,
-    file_path: str,
-    snapshot_dir: Path,
-) -> float:
-  """
-    Calculate the download percentage for a file by comparing local and remote sizes.
-    """
-  try:
-    local_path = snapshot_dir / file_path
-    if not await aios.path.exists(local_path):
-      return 0
-
-    # Get local file size first
-    local_size = await aios.path.getsize(local_path)
-    if local_size == 0:
-      return 0
-
-    # Check remote size
-    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-    url = urljoin(base_url, file_path)
-    headers = await get_auth_headers()
-
-    # Use HEAD request with redirect following for all files
-    async with session.head(url, headers=headers, allow_redirects=True) as response:
-      if response.status != 200:
-        if DEBUG >= 2:
-          print(f"Failed to get remote file info for {file_path}: {response.status}")
-        return 0
-
-      remote_size = int(response.headers.get('Content-Length', 0))
-
-      if remote_size == 0:
-        if DEBUG >= 2:
-          print(f"Remote size is 0 for {file_path}")
-        return 0
-
-      # Only return 100% if sizes match exactly
-      if local_size == remote_size:
-        return 100.0
-
-      # Calculate percentage based on sizes
-      return (local_size / remote_size) * 100 if remote_size > 0 else 0
-
-  except Exception as e:
-    if DEBUG >= 2:
-      print(f"Error checking file download status for {file_path}: {e}")
-    return 0
-
-async def has_hf_home_read_access() -> bool:
-  hf_home = get_hf_home()
-  try: return await aios.access(hf_home, os.R_OK)
-  except OSError: return False
-
-async def has_hf_home_write_access() -> bool:
-  hf_home = get_hf_home()
-  try: return await aios.access(hf_home, os.W_OK)
-  except OSError: return False
-

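What remains of `hf_helpers.py` after this PR is pattern filtering: given a weight map (tensor name → safetensors file) and a shard, `get_allow_patterns` returns the glob patterns needed to fetch only that shard's files plus config/tokenizer assets. A usage sketch with an illustrative two-file weight map; the tensor and file names are made up:

```python
from exo.inference.shard import Shard
from exo.download.hf.hf_helpers import get_allow_patterns

weight_map = {
  "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
  "model.layers.31.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
}
# A shard covering only layer 0 of a 32-layer model.
shard = Shard(model_id="llama-3.2-3b", start_layer=0, end_layer=0, n_layers=32)
# Expect config/tokenizer patterns plus only the first safetensors file.
print(get_allow_patterns(weight_map, shard))
```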
+ 0 - 172
exo/download/hf/hf_shard_download.py

@@ -1,172 +0,0 @@
-import asyncio
-import traceback
-from pathlib import Path
-from typing import Dict, List, Tuple, Optional, Union
-from exo.inference.shard import Shard
-from exo.download.shard_download import ShardDownloader
-from exo.download.download_progress import RepoProgressEvent
-from exo.download.hf.hf_helpers import (
-    download_repo_files, RepoProgressEvent, get_weight_map, 
-    get_allow_patterns, get_repo_root, fetch_file_list, 
-    get_local_snapshot_dir, get_file_download_percentage,
-    filter_repo_objects
-)
-from exo.helpers import AsyncCallbackSystem, DEBUG
-from exo.models import model_cards, get_repo
-import aiohttp
-from aiofiles import os as aios
-
-
-class HFShardDownloader(ShardDownloader):
-  def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
-    self.quick_check = quick_check
-    self.max_parallel_downloads = max_parallel_downloads
-    self.active_downloads: Dict[Shard, asyncio.Task] = {}
-    self.completed_downloads: Dict[Shard, Path] = {}
-    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
-    self.current_shard: Optional[Shard] = None
-    self.current_repo_id: Optional[str] = None
-    self.revision: str = "main"
-
-  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
-    self.current_shard = shard
-    self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
-    repo_name = get_repo(shard.model_id, inference_engine_name)
-    if shard in self.completed_downloads:
-      return self.completed_downloads[shard]
-    if self.quick_check:
-      repo_root = get_repo_root(repo_name)
-      snapshots_dir = repo_root/"snapshots"
-      if snapshots_dir.exists():
-        visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
-        if visible_dirs:
-          most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime)
-          return most_recent_dir
-
-    # If a download on this shard is already in progress, keep that one
-    for active_shard in self.active_downloads:
-      if active_shard == shard:
-        if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
-        return await self.active_downloads[shard]
-
-    # Cancel any downloads for this model_id on a different shard
-    existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
-    for active_shard in existing_active_shards:
-      if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
-      task = self.active_downloads[active_shard]
-      task.cancel()
-      try:
-        await task
-      except asyncio.CancelledError:
-        pass  # This is expected when cancelling a task
-      except Exception as e:
-        if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
-        traceback.print_exc()
-    self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
-
-    # Start new download
-    download_task = asyncio.create_task(self._download_shard(shard, repo_name))
-    self.active_downloads[shard] = download_task
-    try:
-      path = await download_task
-      self.completed_downloads[shard] = path
-      return path
-    finally:
-      # Ensure the task is removed even if an exception occurs
-      print(f"Removing download task for {shard}: {shard in self.active_downloads}")
-      if shard in self.active_downloads:
-        self.active_downloads.pop(shard)
-
-  async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
-    async def wrapped_progress_callback(event: RepoProgressEvent):
-      self._on_progress.trigger_all(shard, event)
-
-    weight_map = await get_weight_map(repo_name)
-    allow_patterns = get_allow_patterns(weight_map, shard)
-
-    return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
-
-  @property
-  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
-    return self._on_progress
-
-  async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
-    if not self.current_shard or not self.current_repo_id:
-      if DEBUG >= 2:
-        print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
-      return None
-
-    try:
-      # If no snapshot directory exists, return None - no need to check remote files
-      snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
-      if not snapshot_dir:
-        if DEBUG >= 2:
-          print(f"No snapshot directory found for {self.current_repo_id}")
-        return None
-
-      if not await aios.path.exists(snapshot_dir/"model_index.json"):
-      # Get the weight map to know what files we need
-        weight_map = await get_weight_map(self.current_repo_id, self.revision)
-        if not weight_map:
-          if DEBUG >= 2:
-            print(f"No weight map found for {self.current_repo_id}")
-          return None
-
-        # Get all files needed for this shard
-        patterns = get_allow_patterns(weight_map, self.current_shard)
-      else:
-        patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
-
-
-      # Check download status for all relevant files
-      status = {}
-      total_bytes = 0
-      downloaded_bytes = 0
-
-      async with aiohttp.ClientSession() as session:
-        file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
-        relevant_files = list(
-            filter_repo_objects(
-                file_list, allow_patterns=patterns, key=lambda x: x["path"]))
-
-        for file in relevant_files:
-          file_size = file["size"]
-          total_bytes += file_size
-
-          percentage = await get_file_download_percentage(
-              session,
-              self.current_repo_id,
-              self.revision,
-              file["path"],
-              snapshot_dir,
-          )
-          status[file["path"]] = percentage
-          downloaded_bytes += (file_size * (percentage / 100))
-
-        # Add overall progress weighted by file size
-        if total_bytes > 0:
-          status["overall"] = (downloaded_bytes / total_bytes) * 100
-        else:
-          status["overall"] = 0
-          
-        # Add total size in bytes
-        status["total_size"] = total_bytes
-        if status["overall"] != 100:
-          status["total_downloaded"] = downloaded_bytes
-        
-
-        if DEBUG >= 2:
-          print(f"Download calculation for {self.current_repo_id}:")
-          print(f"Total bytes: {total_bytes}")
-          print(f"Downloaded bytes: {downloaded_bytes}")
-        if DEBUG >= 3:
-          for file in relevant_files:
-            print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
-
-      return status
-
-    except Exception as e:
-      if DEBUG >= 3:
-        print(f"Error getting shard download status: {e}")
-        traceback.print_exc()
-      return None

+ 226 - 0
exo/download/new_shard_download.py

@@ -0,0 +1,226 @@
+from exo.inference.shard import Shard
+from exo.models import get_repo
+from pathlib import Path
+from exo.download.hf.hf_helpers import get_hf_endpoint, get_auth_headers, filter_repo_objects, get_allow_patterns
+from exo.download.shard_download import ShardDownloader
+from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent
+from exo.helpers import AsyncCallbackSystem, DEBUG
+from exo.models import get_supported_models, build_full_shard
+import os
+import aiofiles.os as aios
+import aiohttp
+import aiofiles
+from urllib.parse import urljoin
+from typing import Callable, Union, Tuple, Dict, List
+import time
+from datetime import timedelta
+import asyncio
+import json
+import traceback
+import shutil
+import tempfile
+
+def exo_home() -> Path:
+  return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
+
+def exo_tmp() -> Path:
+  return Path(tempfile.gettempdir())/"exo"
+
+async def ensure_exo_tmp() -> Path:
+  await aios.makedirs(exo_tmp(), exist_ok=True)
+  return exo_tmp()
+
+async def has_exo_home_read_access() -> bool:
+  try: return await aios.access(exo_home(), os.R_OK)
+  except OSError: return False
+
+async def has_exo_home_write_access() -> bool:
+  try: return await aios.access(exo_home(), os.W_OK)
+  except OSError: return False
+
+async def ensure_downloads_dir() -> Path:
+  downloads_dir = exo_home()/"downloads"
+  await aios.makedirs(downloads_dir, exist_ok=True)
+  return downloads_dir
+
+async def delete_model(model_id: str, inference_engine_name: str) -> bool:
+  repo_id = get_repo(model_id, inference_engine_name)
+  model_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
+  if not await aios.path.exists(model_dir): return False
+  await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
+  return True
+
+async def seed_models(seed_dir: Union[str, Path]):
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(seed_dir)
+  dest_dir = await ensure_downloads_dir()
+  for path in source_dir.iterdir():
+    if path.is_dir() and path.name.startswith("models--"):
+      dest_path = dest_dir/path.name
+      if await aios.path.exists(dest_path): print('Skipping moving model to .cache directory')
+      else:
+        try: await aios.rename(str(path), str(dest_path))
+        except:
+          print(f"Error seeding model {path} to {dest_path}")
+          traceback.print_exc()
+
+async def fetch_file_list(session, repo_id, revision, path=""):
+  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
+  url = f"{api_url}/{path}" if path else api_url
+
+  headers = await get_auth_headers()
+  async with session.get(url, headers=headers) as response:
+    if response.status == 200:
+      data = await response.json()
+      files = []
+      for item in data:
+        if item["type"] == "file":
+          files.append({"path": item["path"], "size": item["size"]})
+        elif item["type"] == "directory":
+          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
+          files.extend(subfiles)
+      return files
+    else:
+      raise Exception(f"Failed to fetch file list: {response.status}")
+
+async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
+  if (target_dir/path).exists(): return target_dir/path
+  await aios.makedirs((target_dir/path).parent, exist_ok=True)
+  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
+  url = urljoin(base_url, path)
+  headers = await get_auth_headers()
+  async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
+    assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
+    length = int(r.headers.get('content-length', 0))
+    n_read = 0
+    async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
+      while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
+      await aios.rename(temp_file.name, target_dir/path)
+    return target_dir/path
+
+def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
+  all_total_bytes = sum([p.total for p in file_progress.values()])
+  all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()])
+  all_downloaded_bytes_this_session = sum([p.downloaded_this_session for p in file_progress.values()])
+  elapsed_time = time.time() - all_start_time
+  all_speed = all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
+  all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
+  status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress"
+  return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)
+
+async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
+  target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--")
+  async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=10, sock_read=1800, sock_connect=10)) as session:
+    index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
+    async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
+    return index_data.get("weight_map")
+
+async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> List[str]:
+  try:
+    weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
+    return get_allow_patterns(weight_map, shard)
+  except:
+    if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}")
+    if DEBUG >= 1: traceback.print_exc()
+    return ["*"]
+
+async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
+  if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
+  repo_id = get_repo(shard.model_id, inference_engine_classname)
+  revision = "main"
+  target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
+  if not skip_download: await aios.makedirs(target_dir, exist_ok=True)
+
+  if repo_id is None:
+    raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}")
+
+  allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname)
+  if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
+
+  all_start_time = time.time()
+  async with aiohttp.ClientSession() as session:
+    file_list = await fetch_file_list(session, repo_id, revision)
+    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
+    file_progress: Dict[str, RepoFileProgressEvent] = {}
+    def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
+      start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
+      downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
+      speed = downloaded_this_session / (time.time() - start_time)
+      eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
+      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "in_progress", start_time)
+      on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
+      if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
+    for file in filtered_file_list:
+      downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
+      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time())
+
+    semaphore = asyncio.Semaphore(max_parallel_downloads)
+    async def download_with_semaphore(file):
+      async with semaphore:
+        await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
+    if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
+    final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
+    on_progress.trigger_all(shard, final_repo_progress)
+    return target_dir, final_repo_progress
+
+def new_shard_downloader() -> ShardDownloader:
+  return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))
+
+class SingletonShardDownloader(ShardDownloader):
+  def __init__(self, shard_downloader: ShardDownloader):
+    self.shard_downloader = shard_downloader
+    self.active_downloads: Dict[Shard, asyncio.Task] = {}
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self.shard_downloader.on_progress
+
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    if shard not in self.active_downloads: self.active_downloads[shard] = asyncio.create_task(self.shard_downloader.ensure_shard(shard, inference_engine_name))
+    try: return await self.active_downloads[shard]
+    finally:
+      if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]
+
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+
+class CachedShardDownloader(ShardDownloader):
+  def __init__(self, shard_downloader: ShardDownloader):
+    self.shard_downloader = shard_downloader
+    self.cache: Dict[tuple[str, Shard], Path] = {}
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self.shard_downloader.on_progress
+
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    if (inference_engine_name, shard) in self.cache:
+      if DEBUG >= 2: print(f"ensure_shard cache hit {shard=} for {inference_engine_name}")
+      return self.cache[(inference_engine_name, shard)]
+    if DEBUG >= 2: print(f"ensure_shard cache miss {shard=} for {inference_engine_name}")
+    target_dir = await self.shard_downloader.ensure_shard(shard, inference_engine_name)
+    self.cache[(inference_engine_name, shard)] = target_dir
+    return target_dir
+
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+
+class NewShardDownloader(ShardDownloader):
+  def __init__(self):
+    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self._on_progress
+
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
+    return target_dir
+
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+    if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
+    downloads = await asyncio.gather(*[download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
+    if DEBUG >= 6: print("Downloaded shards:", downloads)
+    if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
+    return [d for d in downloads if not isinstance(d, Exception)]
+
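
A minimal usage sketch (not part of the diff) of the downloader stack that new_shard_downloader() assembles above; the shard and engine name mirror the test added in this PR:

import asyncio
from exo.download.new_shard_download import new_shard_downloader
from exo.inference.shard import Shard

async def demo():
  downloader = new_shard_downloader()  # SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))
  shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16)
  # SingletonShardDownloader collapses concurrent requests into one in-flight task...
  path_a, path_b = await asyncio.gather(
    downloader.ensure_shard(shard, "MLXDynamicShardInferenceEngine"),
    downloader.ensure_shard(shard, "MLXDynamicShardInferenceEngine"),
  )
  # ...and CachedShardDownloader memoizes the resolved path for later calls.
  path_c = await downloader.ensure_shard(shard, "MLXDynamicShardInferenceEngine")
  assert path_a == path_b == path_c

if __name__ == "__main__":
  asyncio.run(demo())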

+ 2 - 2
exo/download/shard_download.py

@@ -27,7 +27,7 @@ class ShardDownloader(ABC):
     pass
 
   @abstractmethod
-  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
     """Get the download status of shards.
     
     Returns:
@@ -45,5 +45,5 @@ class NoopShardDownloader(ShardDownloader):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return AsyncCallbackSystem()
 
-  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> Optional[Dict[str, float]]:
     return None
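
For orientation, a bare-bones subclass satisfying the updated interface; this is a sketch against the methods visible in this diff, not part of the PR, and InMemoryShardDownloader is a hypothetical name:

from pathlib import Path
from typing import Tuple
from exo.helpers import AsyncCallbackSystem
from exo.inference.shard import Shard
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent

class InMemoryShardDownloader(ShardDownloader):
  def __init__(self, paths: dict[str, Path]):
    self.paths = paths  # pre-resolved model_id -> local directory
    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

  @property
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
    return self._on_progress

  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
    return self.paths[shard.model_id]

  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
    return []  # nothing is ever in flight in this toy implementation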

+ 15 - 0
exo/download/test_new_shard_download.py

@@ -0,0 +1,15 @@
+from exo.download.new_shard_download import download_shard, NewShardDownloader
+from exo.inference.shard import Shard
+from pathlib import Path
+import asyncio
+
+async def test_new_shard_download():
+  shard_downloader = NewShardDownloader()
+  shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
+  await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
+  download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
+  print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})
+
+if __name__ == "__main__":
+  asyncio.run(test_new_shard_download())
+

+ 2 - 2
exo/inference/mlx/test_non_blocking.py

@@ -2,7 +2,7 @@ import asyncio
 import time
 import numpy as np
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.download.new_shard_download import NewShardDownloader
 from exo.inference.shard import Shard
 from exo.models import build_base_shard
 from collections import deque
@@ -10,7 +10,7 @@ from statistics import mean, median
 
 async def test_non_blocking():
     # Setup
-    shard_downloader = HFShardDownloader()
+    shard_downloader = NewShardDownloader()
     engine = MLXDynamicShardInferenceEngine(shard_downloader)
     _shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
     shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)

+ 3 - 5
exo/inference/test_inference_engine.py

@@ -1,6 +1,6 @@
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.inference.inference_engine import InferenceEngine
+from exo.download.new_shard_download import NewShardDownloader
 from exo.inference.shard import Shard
 from exo.helpers import DEBUG
 import os
@@ -44,13 +44,11 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(next_resp_full, resp4)
 
 
-asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))
+asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(NewShardDownloader()), MLXDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 16))
 
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
   import tinygrad
   import os
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-  asyncio.run(
-    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
-  )
+  asyncio.run(test_inference_engine(TinygradDynamicShardInferenceEngine(NewShardDownloader()), TinygradDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 32))

+ 16 - 17
exo/inference/tokenizers.py

@@ -1,12 +1,11 @@
 import traceback
-from aiofiles import os as aios
 from os import PathLike
-from pathlib import Path
+from aiofiles import os as aios
 from typing import Union
 from transformers import AutoTokenizer, AutoProcessor
 import numpy as np
-from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
+from exo.download.new_shard_download import ensure_downloads_dir
 
 
 class DummyTokenizer:
@@ -24,25 +23,25 @@ class DummyTokenizer:
     return "dummy" * len(tokens)
 
 
-async def resolve_tokenizer(model_id: str):
-  if model_id == "dummy":
+async def resolve_tokenizer(repo_id: Union[str, PathLike]):
+  if repo_id == "dummy":
     return DummyTokenizer()
-  local_path = await get_local_snapshot_dir(model_id)
+  local_path = await ensure_downloads_dir()/str(repo_id).replace("/", "--")
   if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
   try:
     if local_path and await aios.path.exists(local_path):
-      if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
+      if DEBUG >= 2: print(f"Resolving tokenizer for {repo_id=} from {local_path=}")
       return await _resolve_tokenizer(local_path)
   except:
-    if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
+    if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {repo_id=} normally...")
     if DEBUG >= 5: traceback.print_exc()
-  return await _resolve_tokenizer(model_id)
+  return await _resolve_tokenizer(repo_id)
 
 
-async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
+async def _resolve_tokenizer(repo_id_or_local_path: Union[str, PathLike]):
   try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
-    processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {repo_id_or_local_path}")
+    processor = AutoProcessor.from_pretrained(repo_id_or_local_path, use_fast=True if "Mistral-Large" in f"{repo_id_or_local_path}" else False, trust_remote_code=True)
     if not hasattr(processor, 'eos_token_id'):
       processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
     if not hasattr(processor, 'encode'):
@@ -51,14 +50,14 @@ async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
       processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load processor for {repo_id_or_local_path}. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())
 
   try:
-    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
-    return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {repo_id_or_local_path}")
+    return AutoTokenizer.from_pretrained(repo_id_or_local_path, trust_remote_code=True)
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {repo_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())
 
-  raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
+  raise ValueError(f"[TODO] Unsupported model: {repo_id_or_local_path}")
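
Usage stays a one-liner; a sketch with a hypothetical repo id, noting that the local lookup follows the repo_id.replace("/", "--") directory convention introduced above:

import asyncio
from exo.inference.tokenizers import resolve_tokenizer

async def demo():
  # Served from EXO_HOME downloads when "mlx-community--Llama-3.2-1B-Instruct-4bit"
  # exists locally; otherwise falls back to AutoProcessor/AutoTokenizer resolution.
  tokenizer = await resolve_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
  print(tokenizer.encode("hello exo"))

if __name__ == "__main__":
  asyncio.run(demo())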

+ 12 - 13
exo/main.py

@@ -23,19 +23,18 @@ from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
 from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
-from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.download.shard_download import ShardDownloader, NoopShardDownloader
+from exo.download.download_progress import RepoProgressEvent
+from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, exo_home, seed_models
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
-from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
 import uvloop
 from contextlib import asynccontextmanager
 import concurrent.futures
-import socket
 import resource
 import psutil
 
@@ -117,8 +116,7 @@ print_yellow_exo()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
-                                                      max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
+shard_downloader: ShardDownloader = new_shard_downloader() if args.inference_engine != "dummy" else NoopShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 print(f"Inference engine name after selection: {inference_engine_name}")
 
@@ -175,10 +173,10 @@ node = Node(
   None,
   inference_engine,
   discovery,
+  shard_downloader,
   partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
   max_generate_tokens=args.max_generate_tokens,
   topology_viz=topology_viz,
-  shard_downloader=shard_downloader,
   default_sample_temperature=args.default_temp
 )
 server = GRPCServer(node, args.node_host, args.node_port)
@@ -223,6 +221,7 @@ last_broadcast_time = 0
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
   global last_broadcast_time
   current_time = time.time()
+  if event.status == "not_started": return
   if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
     last_broadcast_time = current_time
     asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
@@ -322,13 +321,13 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
 async def main():
   loop = asyncio.get_running_loop()
 
-  # Check HuggingFace directory permissions
-  hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
-  if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
+  # Check exo directory permissions
+  home, has_read, has_write = exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access()
+  if DEBUG >= 1: print(f"exo home directory: {home}")
   print(f"{has_read=}, {has_write=}")
   if not has_read or not has_write:
     print(f"""
-          WARNING: Limited permissions for model storage directory: {hf_home}.
+          WARNING: Limited permissions for exo home directory: {home}.
           This may prevent model downloads from working correctly.
           {"❌ No read access" if not has_read else ""}
           {"❌ No write access" if not has_write else ""}
@@ -337,9 +336,9 @@ async def main():
   if not args.models_seed_dir is None:
     try:
       models_seed_dir = clean_path(args.models_seed_dir)
-      await move_models_to_hf(models_seed_dir)
+      await seed_models(models_seed_dir)
     except Exception as e:
-      print(f"Error moving models to .cache/huggingface: {e}")
+      print(f"Error seeding models: {e}")
 
   def restore_cursor():
     if platform.system() != "Windows":
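
The amended broadcast throttle, restated in isolation as a sketch of the pattern (not exo API): drop "not_started" events, always flush completions, and rate-limit everything else to one broadcast per 100ms.

import time

last_broadcast_time = 0.0

def should_broadcast(status: str) -> bool:
  global last_broadcast_time
  if status == "not_started": return False  # new in this PR: skip no-op events
  now = time.time()
  if status == "complete" or now - last_broadcast_time >= 0.1:  # completions always flush
    last_broadcast_time = now
    return True
  return False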

+ 12 - 1
exo/models.py

@@ -175,8 +175,11 @@ pretty_name = {
   "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
   "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
   "deepseek-coder-v2.5": "Deepseek Coder V2.5",
   "deepseek-coder-v2.5": "Deepseek Coder V2.5",
   "deepseek-v3": "Deepseek V3",
   "deepseek-v3": "Deepseek V3",
+  "deepseek-v3-3bit": "Deepseek V3 (3-bit)",
   "deepseek-r1": "Deepseek R1",
   "deepseek-r1": "Deepseek R1",
+  "deepseek-r1-3bit": "Deepseek R1 (3-bit)",
   "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
   "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
+  "qwen-2.5-0.5b": "Qwen 2.5 0.5B",
   "qwen-2.5-1.5b": "Qwen 2.5 1.5B",
   "qwen-2.5-1.5b": "Qwen 2.5 1.5B",
   "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
   "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
   "qwen-2.5-3b": "Qwen 2.5 3B",
   "qwen-2.5-3b": "Qwen 2.5 3B",
@@ -232,6 +235,9 @@ pretty_name = {
 def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
 def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
   return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
   return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
 
 
+def get_pretty_name(model_id: str) -> Optional[str]:
+  return pretty_name.get(model_id, None)
+
 def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
 def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
   repo = get_repo(model_id, inference_engine_classname)
   repo = get_repo(model_id, inference_engine_classname)
   n_layers = model_cards.get(model_id, {}).get("layers", 0)
   n_layers = model_cards.get(model_id, {}).get("layers", 0)
@@ -239,7 +245,12 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
     return None
     return None
   return Shard(model_id, 0, 0, n_layers)
   return Shard(model_id, 0, 0, n_layers)
 
 
-def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
+def build_full_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
+  base_shard = build_base_shard(model_id, inference_engine_classname)
+  if base_shard is None: return None
+  return Shard(base_shard.model_id, 0, base_shard.n_layers - 1, base_shard.n_layers)
+
+def get_supported_models(supported_inference_engine_lists: Optional[List[List[str]]] = None) -> List[str]:
   if not supported_inference_engine_lists:
   if not supported_inference_engine_lists:
     return list(model_cards.keys())
     return list(model_cards.keys())
 
 
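
A quick illustration of the new helper beside the existing one, assuming the llama-3.2-1b card declares 16 layers as the tests in this PR do:

from exo.models import build_base_shard, build_full_shard

base = build_base_shard("llama-3.2-1b", "MLXDynamicShardInferenceEngine")
full = build_full_shard("llama-3.2-1b", "MLXDynamicShardInferenceEngine")
# base covers only the first layer: Shard(model_id='llama-3.2-1b', start_layer=0, end_layer=0, n_layers=16)
# full covers every layer:          Shard(model_id='llama-3.2-1b', start_layer=0, end_layer=15, n_layers=16)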

+ 4 - 4
exo/orchestration/node.py

@@ -13,9 +13,9 @@ from exo.topology.partitioning_strategy import Partition, PartitioningStrategy,
 from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
-from exo.download.hf.hf_helpers import RepoProgressEvent
+from exo.download.download_progress import RepoProgressEvent
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
-from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.download.shard_download import ShardDownloader
 
 class Node:
   def __init__(
@@ -24,16 +24,17 @@ class Node:
     server: Server,
     inference_engine: InferenceEngine,
     discovery: Discovery,
+    shard_downloader: ShardDownloader,
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
     default_sample_temperature: float = 0.0,
     topology_viz: Optional[TopologyViz] = None,
-    shard_downloader: Optional[HFShardDownloader] = None,
   ):
     self.id = _id
     self.inference_engine = inference_engine
     self.server = server
     self.discovery = discovery
+    self.shard_downloader = shard_downloader
     self.partitioning_strategy = partitioning_strategy
     self.peers: List[PeerHandle] = {}
     self.topology: Topology = Topology()
@@ -52,7 +53,6 @@ class Node:
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
     self.topology_inference_engines_pool: List[List[str]] = []
-    self.shard_downloader = shard_downloader
     self.outstanding_requests = {}
 
   async def start(self, wait_for_peers: int = 0) -> None:

+ 2 - 2
exo/orchestration/test_node.py

@@ -5,7 +5,7 @@ import pytest
 
 from .node import Node
 from exo.networking.peer_handle import PeerHandle
-
+from exo.download.shard_download import NoopShardDownloader
 
 class TestNode(unittest.IsolatedAsyncioTestCase):
   def setUp(self):
@@ -22,7 +22,7 @@ class TestNode(unittest.IsolatedAsyncioTestCase):
     mock_peer2.id.return_value = "peer2"
     self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
 
-    self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery)
+    self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery, NoopShardDownloader())
 
   async def asyncSetUp(self):
     await self.node.start()

+ 1 - 1
exo/tinychat/index.css

@@ -843,4 +843,4 @@ main {
   font-size: 0.8em;
   color: var(--secondary-color-transparent);
   font-family: monospace;
-}
+}

+ 8 - 5
exo/tinychat/index.js

@@ -75,12 +75,12 @@ document.addEventListener("alpine:init", () => {
       while (true) {
         try {
           await this.populateSelector();
-          // Wait 5 seconds before next poll
-          await new Promise(resolve => setTimeout(resolve, 5000));
+          // Wait 15 seconds before next poll
+          await new Promise(resolve => setTimeout(resolve, 15000));
         } catch (error) {
           console.error('Model polling error:', error);
           // If there's an error, wait before retrying
-          await new Promise(resolve => setTimeout(resolve, 5000));
+          await new Promise(resolve => setTimeout(resolve, 15000));
         }
       }
     },
@@ -637,6 +637,9 @@ document.addEventListener("alpine:init", () => {
       const vizElement = this.$refs.topologyViz;
       vizElement.innerHTML = ''; // Clear existing visualization
 
+      // Helper function to truncate node ID
+      const truncateNodeId = (id) => id.substring(0, 8);
+
       // Create nodes from object
       Object.entries(topologyData.nodes).forEach(([nodeId, node]) => {
         const nodeElement = document.createElement('div');
@@ -647,14 +650,14 @@ document.addEventListener("alpine:init", () => {
         const peerConnectionsHtml = peerConnections.map(peer => `
           <div class="peer-connection">
             <i class="fas fa-arrow-right"></i>
-            <span>To ${peer.to_id}: ${peer.description}</span>
+            <span>To ${truncateNodeId(peer.to_id)}: ${peer.description}</span>
           </div>
         `).join('');
 
         nodeElement.innerHTML = `
           <div class="node-info">
             <span class="status ${nodeId === topologyData.active_node_id ? 'active' : 'inactive'}"></span>
-            <span>${node.model}</span>
+            <span>${node.model} [${truncateNodeId(nodeId)}]</span>
           </div>
           <div class="node-details">
             <span>${node.chip}</span>

+ 1 - 1
exo/viz/test_topology_viz.py

@@ -5,7 +5,7 @@ from exo.viz.topology_viz import TopologyViz
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.partitioning_strategy import Partition
-from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
+from exo.download.download_progress import RepoProgressEvent
 
 
 def create_hf_repo_progress_event(

+ 1 - 1
exo/viz/topology_viz.py

@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple, Dict
 from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
-from exo.download.hf.hf_helpers import RepoProgressEvent
+from exo.download.download_progress import RepoProgressEvent
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
 from rich.console import Console, Group
 from rich.text import Text

+ 0 - 50
extra/download_hf.py

@@ -1,50 +0,0 @@
-import argparse
-import asyncio
-from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent
-
-DEFAULT_ALLOW_PATTERNS = [
-  "*.json",
-  "*.py",
-  "tokenizer.model",
-  "*.tiktoken",
-  "*.txt",
-  "*.safetensors",
-]
-# Always ignore `.git` and `.cache/huggingface` folders in commits
-DEFAULT_IGNORE_PATTERNS = [
-  ".git",
-  ".git/*",
-  "*/.git",
-  "**/.git/**",
-  ".cache/huggingface",
-  ".cache/huggingface/*",
-  "*/.cache/huggingface",
-  "**/.cache/huggingface/**",
-]
-
-
-async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
-  async def progress_callback(event: RepoProgressEvent):
-    print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
-    print(f"Estimated time remaining: {event.overall_eta}")
-    print("File Progress:")
-    for file_path, progress in event.file_progress.items():
-      status_icon = {'not_started': '⚪', 'in_progress': '🔵', 'complete': '✅'}[progress.status]
-      eta_str = str(progress.eta)
-      print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
-            f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
-    print("\n")
-
-  await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
-
-
-if __name__ == "__main__":
-  parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
-  parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
-  parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
-  parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
-  parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
-
-  args = parser.parse_args()
-
-  asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))

+ 0 - 26
test/test_hf.py

@@ -1,26 +0,0 @@
-import os
-import sys
-
-# Add the project root to the Python path
-project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-sys.path.insert(0, project_root)
-
-import asyncio
-from exo.download.hf.hf_helpers import get_weight_map
-
-async def test_get_weight_map():
-  repo_ids = [
-    "mlx-community/quantized-gemma-2b",
-    "mlx-community/Meta-Llama-3.1-8B-4bit",
-    "mlx-community/Meta-Llama-3.1-70B-4bit",
-    "mlx-community/Meta-Llama-3.1-405B-4bit",
-  ]
-  for repo_id in repo_ids:
-    weight_map = await get_weight_map(repo_id)
-    assert weight_map is not None, "Weight map should not be None"
-    assert isinstance(weight_map, dict), "Weight map should be a dictionary"
-    assert len(weight_map) > 0, "Weight map should not be empty"
-    print(f"OK: {repo_id}")
-
-if __name__ == "__main__":
-  asyncio.run(test_get_weight_map())