Parcourir la source

remove a lot of hf bloat

Alex Cheema il y a 3 mois
Parent
commit
1df023023e

+ 11 - 35
exo/api/chatgpt_api.py

@@ -21,19 +21,17 @@ import numpy as np
 import base64
 from io import BytesIO
 import platform
-from exo.download.shard_download import RepoProgressEvent
+from exo.download.download_progress import RepoProgressEvent
+from exo.download.new_shard_download import ensure_downloads_dir, delete_model
+import tempfile
+from exo.apputil import create_animation_mp4
+from collections import defaultdict
 
 if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
   import mlx.core as mx
 else:
   import numpy as mx
 
-import tempfile
-import shutil
-from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
-from exo.apputil import create_animation_mp4
-from collections import defaultdict
-
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
@@ -545,35 +543,13 @@ class ChatGPTAPI:
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
 
   async def handle_delete_model(self, request):
+    model_id = request.match_info.get('model_name')
     try:
-      model_name = request.match_info.get('model_name')
-      if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
-
-      if not model_name or model_name not in model_cards:
-        return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)
-
-      shard = build_base_shard(model_name, self.inference_engine_classname)
-      if not shard:
-        return web.json_response({"detail": "Could not build shard for model"}, status=400)
-
-      repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-      if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
-
-      # Get the HF cache directory using the helper function
-      hf_home = get_hf_home()
-      cache_dir = get_repo_root(repo_id)
-
-      if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
-
-      if os.path.exists(cache_dir):
-        if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
-        try:
-          shutil.rmtree(cache_dir)
-          return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
-        except Exception as e:
-          return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
-      else:
-        return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)
+      if await delete_model(model_id, self.inference_engine_classname): return web.json_response({"status": "success", "message": f"Model {model_id} deleted successfully"})
+      else: return web.json_response({"detail": f"Model {model_id} files not found"}, status=404)
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500)
 
     except Exception as e:
       print(f"Error in handle_delete_model: {str(e)}")

+ 3 - 411
exo/download/hf/hf_helpers.py

@@ -1,36 +1,16 @@
 import aiofiles.os as aios
 from typing import Union
-import asyncio
-import aiohttp
-import json
 import os
-import sys
-import shutil
-from urllib.parse import urljoin
-from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
-from datetime import datetime, timedelta
+from typing import Callable, Optional, Dict, List, Union
 from fnmatch import fnmatch
 from pathlib import Path
-from typing import Generator, Iterable, TypeVar, TypedDict
-from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG, is_frozen
-from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
+from typing import Generator, Iterable, TypeVar
+from exo.helpers import DEBUG
 from exo.inference.shard import Shard
 import aiofiles
 
 T = TypeVar("T")
 
-async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
-  refs_dir = get_repo_root(repo_id)/"refs"
-  refs_file = refs_dir/revision
-  if await aios.path.exists(refs_file):
-    async with aiofiles.open(refs_file, 'r') as f:
-      commit_hash = (await f.read()).strip()
-      snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
-      return snapshot_dir
-  return None
-
-
 def filter_repo_objects(
   items: Iterable[T],
   *,
@@ -48,14 +28,12 @@ def filter_repo_objects(
     ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
 
   if key is None:
-
     def _identity(item: T) -> str:
       if isinstance(item, str):
         return item
       if isinstance(item, Path):
         return str(item)
       raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
-
     key = _identity
 
   for item in items:
@@ -66,22 +44,18 @@ def filter_repo_objects(
       continue
     yield item
 
-
 def _add_wildcard_to_directories(pattern: str) -> str:
   if pattern[-1] == "/":
     return pattern + "*"
   return pattern
 
-
 def get_hf_endpoint() -> str:
   return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
 
-
 def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
   return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))
 
-
 async def get_hf_token():
   """Retrieve the Hugging Face token from the user's HF_HOME directory."""
   token_path = get_hf_home()/"token"
@@ -90,7 +64,6 @@ async def get_hf_token():
       return (await f.read()).strip()
   return None
 
-
 async def get_auth_headers():
   """Get authentication headers if a token is available."""
   token = await get_hf_token()
@@ -98,324 +71,6 @@ async def get_auth_headers():
     return {"Authorization": f"Bearer {token}"}
   return {}
 
-
-def get_repo_root(repo_id: str) -> Path:
-  """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = str(repo_id).replace("/", "--")
-  return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
-
-async def move_models_to_hf(seed_dir: Union[str, Path]):
-  """Move model in resources folder of app to .cache/huggingface/hub"""
-  source_dir = Path(seed_dir)
-  dest_dir = get_hf_home()/"hub"
-  await aios.makedirs(dest_dir, exist_ok=True)  
-  for path in source_dir.iterdir():
-    if path.is_dir() and path.name.startswith("models--"):
-      dest_path = dest_dir / path.name
-      if await aios.path.exists(dest_path):
-        print('Skipping moving model to .cache directory')
-      else:
-        try:
-          await aios.rename(str(path), str(dest_path))
-        except Exception as e:
-          print(f'Error moving model to .cache: {e}')
-    
-    
-    
-async def fetch_file_list(session, repo_id, revision, path=""):
-  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
-  url = f"{api_url}/{path}" if path else api_url
-
-  headers = await get_auth_headers()
-  async with session.get(url, headers=headers) as response:
-    if response.status == 200:
-      data = await response.json()
-      files = []
-      for item in data:
-        if item["type"] == "file":
-          files.append({"path": item["path"], "size": item["size"]})
-        elif item["type"] == "directory":
-          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
-          files.extend(subfiles)
-      return files
-    else:
-      raise Exception(f"Failed to fetch file list: {response.status}")
-
-
-@retry(
-  stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
-)
-async def download_file(
-  session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
-):
-  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-  url = urljoin(base_url, file_path)
-  local_path = os.path.join(save_directory, file_path)
-
-  await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
-
-  # Check if file already exists and get its size
-  local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
-
-  headers = await get_auth_headers()
-  if use_range_request:
-    headers["Range"] = f"bytes={local_file_size}-"
-
-  async with session.get(url, headers=headers) as response:
-    total_size = int(response.headers.get('Content-Length', 0))
-    downloaded_size = local_file_size
-    downloaded_this_session = 0
-    mode = 'ab' if use_range_request else 'wb'
-    percentage = await get_file_download_percentage(
-      session,
-      repo_id,
-      revision,
-      file_path,
-      Path(save_directory)
-    )
-    
-    if percentage == 100:
-      if DEBUG >= 2: print(f"File already downloaded: {file_path}")
-      if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
-      return
-
-    if response.status == 200:
-      # File doesn't support range requests or we're not using them, start from beginning
-      mode = 'wb'
-      downloaded_size = 0
-    elif response.status == 206:
-      # Partial content, resume download
-      content_range = response.headers.get('Content-Range', '')
-      try:
-        total_size = int(content_range.split('/')[-1])
-      except ValueError:
-        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-    elif response.status == 416:
-      # Range not satisfiable, get the actual file size
-      content_range = response.headers.get('Content-Range', '')
-      try:
-        total_size = int(content_range.split('/')[-1])
-        if downloaded_size == total_size:
-          if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
-          if progress_callback:
-            await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-          return
-      except ValueError:
-        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-    else:
-      raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
-
-    if downloaded_size == total_size:
-      if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-      return
-
-    DOWNLOAD_CHUNK_SIZE = 32768
-    start_time = datetime.now()
-    async with aiofiles.open(local_path, mode) as f:
-      async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
-        await f.write(chunk)
-        downloaded_size += len(chunk)
-        downloaded_this_session += len(chunk)
-        if progress_callback and total_size:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
-          remaining_size = total_size - downloaded_size
-          eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
-          status = "in_progress" if downloaded_size < total_size else "complete"
-          if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
-          await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
-    if DEBUG >= 2: print(f"Downloaded: {file_path}")
-
-
-async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
-  repo_root = get_repo_root(repo_id)
-  refs_dir = repo_root/"refs"
-  refs_file = refs_dir/revision
-
-  # Check if we have a cached commit hash
-  if await aios.path.exists(refs_file):
-    async with aiofiles.open(refs_file, 'r') as f:
-      commit_hash = (await f.read()).strip()
-      if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
-      return commit_hash
-
-  # Fetch the commit hash for the given revision
-  async with aiohttp.ClientSession() as session:
-    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
-    headers = await get_auth_headers()
-    async with session.get(api_url, headers=headers) as response:
-      if response.status != 200:
-        raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
-      revision_info = await response.json()
-      commit_hash = revision_info['sha']
-
-  # Cache the commit hash
-  await aios.makedirs(refs_dir, exist_ok=True)
-  async with aiofiles.open(refs_file, 'w') as f:
-    await f.write(commit_hash)
-
-  return commit_hash
-
-
-async def download_repo_files(
-  repo_id: str,
-  revision: str = "main",
-  progress_callback: Optional[RepoProgressCallback] = None,
-  allow_patterns: Optional[Union[List[str], str]] = None,
-  ignore_patterns: Optional[Union[List[str], str]] = None,
-  max_parallel_downloads: int = 4
-) -> Path:
-  repo_root = get_repo_root(repo_id)
-  snapshots_dir = repo_root/"snapshots"
-  cachedreqs_dir = repo_root/"cachedreqs"
-
-  # Ensure directories exist
-  await aios.makedirs(snapshots_dir, exist_ok=True)
-  await aios.makedirs(cachedreqs_dir, exist_ok=True)
-
-  # Resolve revision to commit hash
-  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
-
-  # Set up the snapshot directory
-  snapshot_dir = snapshots_dir/commit_hash
-  await aios.makedirs(snapshot_dir, exist_ok=True)
-
-  # Set up the cached file list directory
-  cached_file_list_dir = cachedreqs_dir/commit_hash
-  await aios.makedirs(cached_file_list_dir, exist_ok=True)
-  cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"
-
-  async with aiohttp.ClientSession() as session:
-    # Check if we have a cached file list
-    if await aios.path.exists(cached_file_list_path):
-      async with aiofiles.open(cached_file_list_path, 'r') as f:
-        file_list = json.loads(await f.read())
-      if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
-    else:
-      file_list = await fetch_file_list(session, repo_id, revision)
-      # Cache the file list
-      async with aiofiles.open(cached_file_list_path, 'w') as f:
-        await f.write(json.dumps(file_list))
-      if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
-
-    model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
-    if model_index_exists:
-      allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
-
-    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
-    total_files = len(filtered_file_list)
-    total_bytes = sum(file["size"] for file in filtered_file_list)
-    file_progress: Dict[str, RepoFileProgressEvent] = {
-      file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
-      for file in filtered_file_list
-    }
-    start_time = datetime.now()
-
-    async def download_with_progress(file_info, progress_state):
-      local_path = snapshot_dir/file_info["path"]
-      if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
-        if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
-        progress_state['completed_files'] += 1
-        progress_state['downloaded_bytes'] += file_info["size"]
-        file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
-        if progress_callback:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-          status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-          await progress_callback(
-            RepoProgressEvent(
-              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-              overall_eta, file_progress, status
-            )
-          )
-        return
-
-      async def file_progress_callback(event: RepoFileProgressEvent):
-        progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
-        progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
-        file_progress[event.file_path] = event
-        if progress_callback:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-          status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
-          await progress_callback(
-            RepoProgressEvent(
-              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-              overall_eta, file_progress, status
-            )
-          )
-
-      await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
-      progress_state['completed_files'] += 1
-      file_progress[
-        file_info["path"]
-      ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
-      if progress_callback:
-        elapsed_time = (datetime.now() - start_time).total_seconds()
-        overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-        remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-        overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-        status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-        await progress_callback(
-          RepoProgressEvent(
-            repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-            overall_eta, file_progress, status
-          )
-        )
-
-    progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
-
-    semaphore = asyncio.Semaphore(max_parallel_downloads)
-
-    async def download_with_semaphore(file_info):
-      async with semaphore:
-        await download_with_progress(file_info, progress_state)
-
-    tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
-    await asyncio.gather(*tasks)
-
-  return snapshot_dir
-
-
-async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
-  """
-    Retrieve the weight map from the model.safetensors.index.json file.
-
-    Args:
-        repo_id (str): The Hugging Face repository ID.
-        revision (str): The revision of the repository to use.
-
-    Returns:
-        Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
-    """
-
-  # Download the index file
-  await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
-
-  # Check if the file exists
-  repo_root = get_repo_root(repo_id)
-  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
-  snapshot_dir = repo_root/"snapshots"/commit_hash
-  index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
-
-  if index_file:
-    index_file_path = snapshot_dir/index_file
-    if await aios.path.exists(index_file_path):
-      async with aiofiles.open(index_file_path, 'r') as f:
-        index_data = json.loads(await f.read())
-      return index_data.get("weight_map")
-
-  return None
-
-
 def extract_layer_num(tensor_name: str) -> Optional[int]:
   # This is a simple example and might need to be adjusted based on the actual naming convention
   parts = tensor_name.split('.')
@@ -424,7 +79,6 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
       return int(part)
   return None
 
-
 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
   default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
   shard_specific_patterns = set()
@@ -442,65 +96,3 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
     shard_specific_patterns = set(["*.safetensors"])
   if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
-
-async def get_file_download_percentage(
-    session: aiohttp.ClientSession,
-    repo_id: str,
-    revision: str,
-    file_path: str,
-    snapshot_dir: Path,
-) -> float:
-  """
-    Calculate the download percentage for a file by comparing local and remote sizes.
-    """
-  try:
-    local_path = snapshot_dir / file_path
-    if not await aios.path.exists(local_path):
-      return 0
-
-    # Get local file size first
-    local_size = await aios.path.getsize(local_path)
-    if local_size == 0:
-      return 0
-
-    # Check remote size
-    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-    url = urljoin(base_url, file_path)
-    headers = await get_auth_headers()
-
-    # Use HEAD request with redirect following for all files
-    async with session.head(url, headers=headers, allow_redirects=True) as response:
-      if response.status != 200:
-        if DEBUG >= 2:
-          print(f"Failed to get remote file info for {file_path}: {response.status}")
-        return 0
-
-      remote_size = int(response.headers.get('Content-Length', 0))
-
-      if remote_size == 0:
-        if DEBUG >= 2:
-          print(f"Remote size is 0 for {file_path}")
-        return 0
-
-      # Only return 100% if sizes match exactly
-      if local_size == remote_size:
-        return 100.0
-
-      # Calculate percentage based on sizes
-      return (local_size / remote_size) * 100 if remote_size > 0 else 0
-
-  except Exception as e:
-    if DEBUG >= 2:
-      print(f"Error checking file download status for {file_path}: {e}")
-    return 0
-
-async def has_hf_home_read_access() -> bool:
-  hf_home = get_hf_home()
-  try: return await aios.access(hf_home, os.R_OK)
-  except OSError: return False
-
-async def has_hf_home_write_access() -> bool:
-  hf_home = get_hf_home()
-  try: return await aios.access(hf_home, os.W_OK)
-  except OSError: return False
-

+ 0 - 172
exo/download/hf/hf_shard_download.py

@@ -1,172 +0,0 @@
-import asyncio
-import traceback
-from pathlib import Path
-from typing import Dict, List, Tuple, Optional, Union
-from exo.inference.shard import Shard
-from exo.download.shard_download import ShardDownloader
-from exo.download.download_progress import RepoProgressEvent
-from exo.download.hf.hf_helpers import (
-    download_repo_files, RepoProgressEvent, get_weight_map, 
-    get_allow_patterns, get_repo_root, fetch_file_list, 
-    get_local_snapshot_dir, get_file_download_percentage,
-    filter_repo_objects
-)
-from exo.helpers import AsyncCallbackSystem, DEBUG
-from exo.models import model_cards, get_repo
-import aiohttp
-from aiofiles import os as aios
-
-
-class HFShardDownloader(ShardDownloader):
-  def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
-    self.quick_check = quick_check
-    self.max_parallel_downloads = max_parallel_downloads
-    self.active_downloads: Dict[Shard, asyncio.Task] = {}
-    self.completed_downloads: Dict[Shard, Path] = {}
-    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
-    self.current_shard: Optional[Shard] = None
-    self.current_repo_id: Optional[str] = None
-    self.revision: str = "main"
-
-  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
-    self.current_shard = shard
-    self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
-    repo_name = get_repo(shard.model_id, inference_engine_name)
-    if shard in self.completed_downloads:
-      return self.completed_downloads[shard]
-    if self.quick_check:
-      repo_root = get_repo_root(repo_name)
-      snapshots_dir = repo_root/"snapshots"
-      if snapshots_dir.exists():
-        visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
-        if visible_dirs:
-          most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime)
-          return most_recent_dir
-
-    # If a download on this shard is already in progress, keep that one
-    for active_shard in self.active_downloads:
-      if active_shard == shard:
-        if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
-        return await self.active_downloads[shard]
-
-    # Cancel any downloads for this model_id on a different shard
-    existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
-    for active_shard in existing_active_shards:
-      if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
-      task = self.active_downloads[active_shard]
-      task.cancel()
-      try:
-        await task
-      except asyncio.CancelledError:
-        pass  # This is expected when cancelling a task
-      except Exception as e:
-        if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
-        traceback.print_exc()
-    self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
-
-    # Start new download
-    download_task = asyncio.create_task(self._download_shard(shard, repo_name))
-    self.active_downloads[shard] = download_task
-    try:
-      path = await download_task
-      self.completed_downloads[shard] = path
-      return path
-    finally:
-      # Ensure the task is removed even if an exception occurs
-      print(f"Removing download task for {shard}: {shard in self.active_downloads}")
-      if shard in self.active_downloads:
-        self.active_downloads.pop(shard)
-
-  async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
-    async def wrapped_progress_callback(event: RepoProgressEvent):
-      self._on_progress.trigger_all(shard, event)
-
-    weight_map = await get_weight_map(repo_name)
-    allow_patterns = get_allow_patterns(weight_map, shard)
-
-    return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
-
-  @property
-  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
-    return self._on_progress
-
-  async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
-    if not self.current_shard or not self.current_repo_id:
-      if DEBUG >= 2:
-        print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
-      return None
-
-    try:
-      # If no snapshot directory exists, return None - no need to check remote files
-      snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
-      if not snapshot_dir:
-        if DEBUG >= 2:
-          print(f"No snapshot directory found for {self.current_repo_id}")
-        return None
-
-      if not await aios.path.exists(snapshot_dir/"model_index.json"):
-      # Get the weight map to know what files we need
-        weight_map = await get_weight_map(self.current_repo_id, self.revision)
-        if not weight_map:
-          if DEBUG >= 2:
-            print(f"No weight map found for {self.current_repo_id}")
-          return None
-
-        # Get all files needed for this shard
-        patterns = get_allow_patterns(weight_map, self.current_shard)
-      else:
-        patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
-
-
-      # Check download status for all relevant files
-      status = {}
-      total_bytes = 0
-      downloaded_bytes = 0
-
-      async with aiohttp.ClientSession() as session:
-        file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
-        relevant_files = list(
-            filter_repo_objects(
-                file_list, allow_patterns=patterns, key=lambda x: x["path"]))
-
-        for file in relevant_files:
-          file_size = file["size"]
-          total_bytes += file_size
-
-          percentage = await get_file_download_percentage(
-              session,
-              self.current_repo_id,
-              self.revision,
-              file["path"],
-              snapshot_dir,
-          )
-          status[file["path"]] = percentage
-          downloaded_bytes += (file_size * (percentage / 100))
-
-        # Add overall progress weighted by file size
-        if total_bytes > 0:
-          status["overall"] = (downloaded_bytes / total_bytes) * 100
-        else:
-          status["overall"] = 0
-          
-        # Add total size in bytes
-        status["total_size"] = total_bytes
-        if status["overall"] != 100:
-          status["total_downloaded"] = downloaded_bytes
-        
-
-        if DEBUG >= 2:
-          print(f"Download calculation for {self.current_repo_id}:")
-          print(f"Total bytes: {total_bytes}")
-          print(f"Downloaded bytes: {downloaded_bytes}")
-        if DEBUG >= 3:
-          for file in relevant_files:
-            print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
-
-      return status
-
-    except Exception as e:
-      if DEBUG >= 3:
-        print(f"Error getting shard download status: {e}")
-        traceback.print_exc()
-      return None

+ 34 - 3
exo/download/hf/new_shard_download.py → exo/download/new_shard_download.py

@@ -16,15 +16,46 @@ import time
 from datetime import timedelta
 import asyncio
 import json
+import traceback
+import shutil
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
 
+async def has_exo_home_read_access() -> bool:
+  try: return await aios.access(exo_home(), os.R_OK)
+  except OSError: return False
+
+async def has_exo_home_write_access() -> bool:
+  try: return await aios.access(exo_home(), os.W_OK)
+  except OSError: return False
+
 async def ensure_downloads_dir() -> Path:
   downloads_dir = exo_home()/"downloads"
   await aios.makedirs(downloads_dir, exist_ok=True)
   return downloads_dir
 
+async def delete_model(model_id: str, inference_engine_name: str) -> bool:
+  repo_id = get_repo(model_id, inference_engine_name)
+  model_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
+  if not await aios.path.exists(model_dir): return False
+  await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
+  return True
+
+async def seed_models(seed_dir: Union[str, Path]):
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(seed_dir)
+  dest_dir = await ensure_downloads_dir()
+  for path in source_dir.iterdir():
+    if path.is_dir() and path.name.startswith("models--"):
+      dest_path = dest_dir/path.name
+      if await aios.path.exists(dest_path): print('Skipping moving model to .cache directory')
+      else:
+        try: await aios.rename(str(path), str(dest_path))
+        except:
+          print(f"Error seeding model {path} to {dest_path}")
+          traceback.print_exc()
+
 async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
@@ -44,8 +75,9 @@ async def fetch_file_list(session, repo_id, revision, path=""):
     else:
       raise Exception(f"Failed to fetch file list: {response.status}")
 
-async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Optional[Callable[[int, int], None]] = None) -> Path:
+async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
   if (target_dir/path).exists(): return target_dir/path
+  await aios.makedirs((target_dir/path).parent, exist_ok=True)
   base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
   url = urljoin(base_url, path)
   headers = await get_auth_headers()
@@ -71,8 +103,7 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]
   target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
   async with aiohttp.ClientSession() as session:
     index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
-    async with aiofiles.open(index_file, 'r') as f:
-      index_data = json.loads(await f.read())
+    async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
     return index_data.get("weight_map")
 
 async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> list[str]:

+ 0 - 0
exo/download/hf/test_new_shard_download.py → exo/download/test_new_shard_download.py


+ 16 - 17
exo/inference/tokenizers.py

@@ -1,12 +1,11 @@
 import traceback
-from aiofiles import os as aios
 from os import PathLike
-from pathlib import Path
+from aiofiles import os as aios
 from typing import Union
 from transformers import AutoTokenizer, AutoProcessor
 import numpy as np
-from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
+from exo.download.new_shard_download import ensure_downloads_dir
 
 
 class DummyTokenizer:
@@ -24,25 +23,25 @@ class DummyTokenizer:
     return "dummy" * len(tokens)
 
 
-async def resolve_tokenizer(model_id: str):
-  if model_id == "dummy":
+async def resolve_tokenizer(repo_id: Union[str, PathLike]):
+  if repo_id == "dummy":
     return DummyTokenizer()
-  local_path = await get_local_snapshot_dir(model_id)
+  local_path = await ensure_downloads_dir()/str(repo_id).replace("/", "--")
   if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
   try:
     if local_path and await aios.path.exists(local_path):
-      if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
+      if DEBUG >= 2: print(f"Resolving tokenizer for {repo_id=} from {local_path=}")
       return await _resolve_tokenizer(local_path)
   except:
-    if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
+    if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {repo_id=} normally...")
     if DEBUG >= 5: traceback.print_exc()
-  return await _resolve_tokenizer(model_id)
+  return await _resolve_tokenizer(repo_id)
 
 
-async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
+async def _resolve_tokenizer(repo_id_or_local_path: Union[str, PathLike]):
   try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
-    processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {repo_id_or_local_path}")
+    processor = AutoProcessor.from_pretrained(repo_id_or_local_path, use_fast=True if "Mistral-Large" in f"{repo_id_or_local_path}" else False, trust_remote_code=True)
     if not hasattr(processor, 'eos_token_id'):
       processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
     if not hasattr(processor, 'encode'):
@@ -51,14 +50,14 @@ async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
       processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load processor for {repo_id_or_local_path}. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())
 
   try:
-    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
-    return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {repo_id_or_local_path}")
+    return AutoTokenizer.from_pretrained(repo_id_or_local_path, trust_remote_code=True)
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {repo_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())
 
-  raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
+  raise ValueError(f"[TODO] Unsupported model: {repo_id_or_local_path}")

+ 9 - 10
exo/main.py

@@ -23,19 +23,18 @@ from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
 from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
-from exo.download.hf.new_shard_download import new_shard_downloader
+from exo.download.shard_download import ShardDownloader, NoopShardDownloader
+from exo.download.download_progress import RepoProgressEvent
+from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, exo_home, seed_models
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
-from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
 import uvloop
 from contextlib import asynccontextmanager
 import concurrent.futures
-import socket
 import resource
 import psutil
 
@@ -321,13 +320,13 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
 async def main():
   loop = asyncio.get_running_loop()
 
-  # Check HuggingFace directory permissions
-  hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
-  if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
+  # Check exo directory permissions
+  home, has_read, has_write = exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access()
+  if DEBUG >= 1: print(f"exo home directory: {home}")
   print(f"{has_read=}, {has_write=}")
   if not has_read or not has_write:
     print(f"""
-          WARNING: Limited permissions for model storage directory: {hf_home}.
+          WARNING: Limited permissions for exo home directory: {home}.
           This may prevent model downloads from working correctly.
           {"❌ No read access" if not has_read else ""}
           {"❌ No write access" if not has_write else ""}
@@ -336,9 +335,9 @@ async def main():
   if not args.models_seed_dir is None:
     try:
       models_seed_dir = clean_path(args.models_seed_dir)
-      await move_models_to_hf(models_seed_dir)
+      await seed_models(models_seed_dir)
     except Exception as e:
-      print(f"Error moving models to .cache/huggingface: {e}")
+      print(f"Error seeding models: {e}")
 
   def restore_cursor():
     if platform.system() != "Windows":

+ 1 - 1
exo/orchestration/node.py

@@ -13,7 +13,7 @@ from exo.topology.partitioning_strategy import Partition, PartitioningStrategy,
 from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
-from exo.download.hf.hf_helpers import RepoProgressEvent
+from exo.download.download_progress import RepoProgressEvent
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.download.shard_download import ShardDownloader
 

+ 1 - 1
exo/viz/test_topology_viz.py

@@ -5,7 +5,7 @@ from exo.viz.topology_viz import TopologyViz
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.partitioning_strategy import Partition
-from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
+from exo.download.download_progress import RepoProgressEvent
 
 
 def create_hf_repo_progress_event(

+ 1 - 1
exo/viz/topology_viz.py

@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple, Dict
 from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
-from exo.download.hf.hf_helpers import RepoProgressEvent
+from exo.download.download_progress import RepoProgressEvent
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
 from rich.console import Console, Group
 from rich.text import Text