
remove a lot of hf bloat

Alex Cheema 3 months ago
commit 1df023023e

+ 11 - 35
exo/api/chatgpt_api.py

@@ -21,19 +21,17 @@ import numpy as np
 import base64
 from io import BytesIO
 import platform
-from exo.download.shard_download import RepoProgressEvent
+from exo.download.download_progress import RepoProgressEvent
+from exo.download.new_shard_download import ensure_downloads_dir, delete_model
+import tempfile
+from exo.apputil import create_animation_mp4
+from collections import defaultdict
 
 
 if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
   import mlx.core as mx
 else:
   import numpy as mx
 
 
-import tempfile
-import shutil
-from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
-from exo.apputil import create_animation_mp4
-from collections import defaultdict
-
 
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
@@ -545,35 +543,13 @@ class ChatGPTAPI:
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
 
 
   async def handle_delete_model(self, request):
+    model_id = request.match_info.get('model_name')
     try:
-      model_name = request.match_info.get('model_name')
-      if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
-
-      if not model_name or model_name not in model_cards:
-        return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)
-
-      shard = build_base_shard(model_name, self.inference_engine_classname)
-      if not shard:
-        return web.json_response({"detail": "Could not build shard for model"}, status=400)
-
-      repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-      if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
-
-      # Get the HF cache directory using the helper function
-      hf_home = get_hf_home()
-      cache_dir = get_repo_root(repo_id)
-
-      if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
-
-      if os.path.exists(cache_dir):
-        if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
-        try:
-          shutil.rmtree(cache_dir)
-          return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
-        except Exception as e:
-          return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
-      else:
-        return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)
+      if await delete_model(model_id, self.inference_engine_classname): return web.json_response({"status": "success", "message": f"Model {model_id} deleted successfully"})
+      else: return web.json_response({"detail": f"Model {model_id} files not found"}, status=404)
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500)
 
 
     except Exception as e:
       print(f"Error in handle_delete_model: {str(e)}")

+ 3 - 411
exo/download/hf/hf_helpers.py

@@ -1,36 +1,16 @@
 import aiofiles.os as aios
 from typing import Union
-import asyncio
-import aiohttp
-import json
 import os
-import sys
-import shutil
-from urllib.parse import urljoin
-from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
-from datetime import datetime, timedelta
+from typing import Callable, Optional, Dict, List, Union
 from fnmatch import fnmatch
 from pathlib import Path
-from typing import Generator, Iterable, TypeVar, TypedDict
-from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG, is_frozen
-from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
+from typing import Generator, Iterable, TypeVar
+from exo.helpers import DEBUG
 from exo.inference.shard import Shard
 import aiofiles
 
 
 T = TypeVar("T")
 
 
-async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
-  refs_dir = get_repo_root(repo_id)/"refs"
-  refs_file = refs_dir/revision
-  if await aios.path.exists(refs_file):
-    async with aiofiles.open(refs_file, 'r') as f:
-      commit_hash = (await f.read()).strip()
-      snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
-      return snapshot_dir
-  return None
-
-
 def filter_repo_objects(
   items: Iterable[T],
   *,
@@ -48,14 +28,12 @@ def filter_repo_objects(
     ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
 
 
   if key is None:
-
     def _identity(item: T) -> str:
       if isinstance(item, str):
         return item
       if isinstance(item, Path):
         return str(item)
       raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
-
     key = _identity
 
 
   for item in items:
@@ -66,22 +44,18 @@ def filter_repo_objects(
       continue
     yield item
 
 
-
 def _add_wildcard_to_directories(pattern: str) -> str:
   if pattern[-1] == "/":
     return pattern + "*"
   return pattern
 
 
-
 def get_hf_endpoint() -> str:
   return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
 
 
-
 def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
   return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))
 
 
-
 async def get_hf_token():
   """Retrieve the Hugging Face token from the user's HF_HOME directory."""
   token_path = get_hf_home()/"token"
@@ -90,7 +64,6 @@ async def get_hf_token():
       return (await f.read()).strip()
   return None
 
 
-
 async def get_auth_headers():
   """Get authentication headers if a token is available."""
   token = await get_hf_token()
@@ -98,324 +71,6 @@ async def get_auth_headers():
     return {"Authorization": f"Bearer {token}"}
     return {"Authorization": f"Bearer {token}"}
   return {}
   return {}
 
 
-
-def get_repo_root(repo_id: str) -> Path:
-  """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = str(repo_id).replace("/", "--")
-  return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
-
-async def move_models_to_hf(seed_dir: Union[str, Path]):
-  """Move model in resources folder of app to .cache/huggingface/hub"""
-  source_dir = Path(seed_dir)
-  dest_dir = get_hf_home()/"hub"
-  await aios.makedirs(dest_dir, exist_ok=True)  
-  for path in source_dir.iterdir():
-    if path.is_dir() and path.name.startswith("models--"):
-      dest_path = dest_dir / path.name
-      if await aios.path.exists(dest_path):
-        print('Skipping moving model to .cache directory')
-      else:
-        try:
-          await aios.rename(str(path), str(dest_path))
-        except Exception as e:
-          print(f'Error moving model to .cache: {e}')
-    
-    
-    
-async def fetch_file_list(session, repo_id, revision, path=""):
-  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
-  url = f"{api_url}/{path}" if path else api_url
-
-  headers = await get_auth_headers()
-  async with session.get(url, headers=headers) as response:
-    if response.status == 200:
-      data = await response.json()
-      files = []
-      for item in data:
-        if item["type"] == "file":
-          files.append({"path": item["path"], "size": item["size"]})
-        elif item["type"] == "directory":
-          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
-          files.extend(subfiles)
-      return files
-    else:
-      raise Exception(f"Failed to fetch file list: {response.status}")
-
-
-@retry(
-  stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
-)
-async def download_file(
-  session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
-):
-  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-  url = urljoin(base_url, file_path)
-  local_path = os.path.join(save_directory, file_path)
-
-  await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
-
-  # Check if file already exists and get its size
-  local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
-
-  headers = await get_auth_headers()
-  if use_range_request:
-    headers["Range"] = f"bytes={local_file_size}-"
-
-  async with session.get(url, headers=headers) as response:
-    total_size = int(response.headers.get('Content-Length', 0))
-    downloaded_size = local_file_size
-    downloaded_this_session = 0
-    mode = 'ab' if use_range_request else 'wb'
-    percentage = await get_file_download_percentage(
-      session,
-      repo_id,
-      revision,
-      file_path,
-      Path(save_directory)
-    )
-    
-    if percentage == 100:
-      if DEBUG >= 2: print(f"File already downloaded: {file_path}")
-      if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
-      return
-
-    if response.status == 200:
-      # File doesn't support range requests or we're not using them, start from beginning
-      mode = 'wb'
-      downloaded_size = 0
-    elif response.status == 206:
-      # Partial content, resume download
-      content_range = response.headers.get('Content-Range', '')
-      try:
-        total_size = int(content_range.split('/')[-1])
-      except ValueError:
-        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-    elif response.status == 416:
-      # Range not satisfiable, get the actual file size
-      content_range = response.headers.get('Content-Range', '')
-      try:
-        total_size = int(content_range.split('/')[-1])
-        if downloaded_size == total_size:
-          if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
-          if progress_callback:
-            await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-          return
-      except ValueError:
-        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-    else:
-      raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
-
-    if downloaded_size == total_size:
-      if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-      return
-
-    DOWNLOAD_CHUNK_SIZE = 32768
-    start_time = datetime.now()
-    async with aiofiles.open(local_path, mode) as f:
-      async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
-        await f.write(chunk)
-        downloaded_size += len(chunk)
-        downloaded_this_session += len(chunk)
-        if progress_callback and total_size:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
-          remaining_size = total_size - downloaded_size
-          eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
-          status = "in_progress" if downloaded_size < total_size else "complete"
-          if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
-          await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
-    if DEBUG >= 2: print(f"Downloaded: {file_path}")
-
-
-async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
-  repo_root = get_repo_root(repo_id)
-  refs_dir = repo_root/"refs"
-  refs_file = refs_dir/revision
-
-  # Check if we have a cached commit hash
-  if await aios.path.exists(refs_file):
-    async with aiofiles.open(refs_file, 'r') as f:
-      commit_hash = (await f.read()).strip()
-      if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
-      return commit_hash
-
-  # Fetch the commit hash for the given revision
-  async with aiohttp.ClientSession() as session:
-    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
-    headers = await get_auth_headers()
-    async with session.get(api_url, headers=headers) as response:
-      if response.status != 200:
-        raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
-      revision_info = await response.json()
-      commit_hash = revision_info['sha']
-
-  # Cache the commit hash
-  await aios.makedirs(refs_dir, exist_ok=True)
-  async with aiofiles.open(refs_file, 'w') as f:
-    await f.write(commit_hash)
-
-  return commit_hash
-
-
-async def download_repo_files(
-  repo_id: str,
-  revision: str = "main",
-  progress_callback: Optional[RepoProgressCallback] = None,
-  allow_patterns: Optional[Union[List[str], str]] = None,
-  ignore_patterns: Optional[Union[List[str], str]] = None,
-  max_parallel_downloads: int = 4
-) -> Path:
-  repo_root = get_repo_root(repo_id)
-  snapshots_dir = repo_root/"snapshots"
-  cachedreqs_dir = repo_root/"cachedreqs"
-
-  # Ensure directories exist
-  await aios.makedirs(snapshots_dir, exist_ok=True)
-  await aios.makedirs(cachedreqs_dir, exist_ok=True)
-
-  # Resolve revision to commit hash
-  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
-
-  # Set up the snapshot directory
-  snapshot_dir = snapshots_dir/commit_hash
-  await aios.makedirs(snapshot_dir, exist_ok=True)
-
-  # Set up the cached file list directory
-  cached_file_list_dir = cachedreqs_dir/commit_hash
-  await aios.makedirs(cached_file_list_dir, exist_ok=True)
-  cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"
-
-  async with aiohttp.ClientSession() as session:
-    # Check if we have a cached file list
-    if await aios.path.exists(cached_file_list_path):
-      async with aiofiles.open(cached_file_list_path, 'r') as f:
-        file_list = json.loads(await f.read())
-      if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
-    else:
-      file_list = await fetch_file_list(session, repo_id, revision)
-      # Cache the file list
-      async with aiofiles.open(cached_file_list_path, 'w') as f:
-        await f.write(json.dumps(file_list))
-      if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
-
-    model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
-    if model_index_exists:
-      allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
-
-    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
-    total_files = len(filtered_file_list)
-    total_bytes = sum(file["size"] for file in filtered_file_list)
-    file_progress: Dict[str, RepoFileProgressEvent] = {
-      file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
-      for file in filtered_file_list
-    }
-    start_time = datetime.now()
-
-    async def download_with_progress(file_info, progress_state):
-      local_path = snapshot_dir/file_info["path"]
-      if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
-        if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
-        progress_state['completed_files'] += 1
-        progress_state['downloaded_bytes'] += file_info["size"]
-        file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
-        if progress_callback:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-          status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-          await progress_callback(
-            RepoProgressEvent(
-              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-              overall_eta, file_progress, status
-            )
-          )
-        return
-
-      async def file_progress_callback(event: RepoFileProgressEvent):
-        progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
-        progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
-        file_progress[event.file_path] = event
-        if progress_callback:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-          status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
-          await progress_callback(
-            RepoProgressEvent(
-              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-              overall_eta, file_progress, status
-            )
-          )
-
-      await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
-      progress_state['completed_files'] += 1
-      file_progress[
-        file_info["path"]
-      ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
-      if progress_callback:
-        elapsed_time = (datetime.now() - start_time).total_seconds()
-        overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-        remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-        overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-        status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-        await progress_callback(
-          RepoProgressEvent(
-            repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-            overall_eta, file_progress, status
-          )
-        )
-
-    progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
-
-    semaphore = asyncio.Semaphore(max_parallel_downloads)
-
-    async def download_with_semaphore(file_info):
-      async with semaphore:
-        await download_with_progress(file_info, progress_state)
-
-    tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
-    await asyncio.gather(*tasks)
-
-  return snapshot_dir
-
-
-async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
-  """
-    Retrieve the weight map from the model.safetensors.index.json file.
-
-    Args:
-        repo_id (str): The Hugging Face repository ID.
-        revision (str): The revision of the repository to use.
-
-    Returns:
-        Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
-    """
-
-  # Download the index file
-  await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
-
-  # Check if the file exists
-  repo_root = get_repo_root(repo_id)
-  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
-  snapshot_dir = repo_root/"snapshots"/commit_hash
-  index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
-
-  if index_file:
-    index_file_path = snapshot_dir/index_file
-    if await aios.path.exists(index_file_path):
-      async with aiofiles.open(index_file_path, 'r') as f:
-        index_data = json.loads(await f.read())
-      return index_data.get("weight_map")
-
-  return None
-
-
 def extract_layer_num(tensor_name: str) -> Optional[int]:
   # This is a simple example and might need to be adjusted based on the actual naming convention
   parts = tensor_name.split('.')
@@ -424,7 +79,6 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
       return int(part)
   return None
 
 
-
 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
   default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
   shard_specific_patterns = set()
@@ -442,65 +96,3 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
     shard_specific_patterns = set(["*.safetensors"])
     shard_specific_patterns = set(["*.safetensors"])
   if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
   return list(default_patterns | shard_specific_patterns)
-
-async def get_file_download_percentage(
-    session: aiohttp.ClientSession,
-    repo_id: str,
-    revision: str,
-    file_path: str,
-    snapshot_dir: Path,
-) -> float:
-  """
-    Calculate the download percentage for a file by comparing local and remote sizes.
-    """
-  try:
-    local_path = snapshot_dir / file_path
-    if not await aios.path.exists(local_path):
-      return 0
-
-    # Get local file size first
-    local_size = await aios.path.getsize(local_path)
-    if local_size == 0:
-      return 0
-
-    # Check remote size
-    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-    url = urljoin(base_url, file_path)
-    headers = await get_auth_headers()
-
-    # Use HEAD request with redirect following for all files
-    async with session.head(url, headers=headers, allow_redirects=True) as response:
-      if response.status != 200:
-        if DEBUG >= 2:
-          print(f"Failed to get remote file info for {file_path}: {response.status}")
-        return 0
-
-      remote_size = int(response.headers.get('Content-Length', 0))
-
-      if remote_size == 0:
-        if DEBUG >= 2:
-          print(f"Remote size is 0 for {file_path}")
-        return 0
-
-      # Only return 100% if sizes match exactly
-      if local_size == remote_size:
-        return 100.0
-
-      # Calculate percentage based on sizes
-      return (local_size / remote_size) * 100 if remote_size > 0 else 0
-
-  except Exception as e:
-    if DEBUG >= 2:
-      print(f"Error checking file download status for {file_path}: {e}")
-    return 0
-
-async def has_hf_home_read_access() -> bool:
-  hf_home = get_hf_home()
-  try: return await aios.access(hf_home, os.R_OK)
-  except OSError: return False
-
-async def has_hf_home_write_access() -> bool:
-  hf_home = get_hf_home()
-  try: return await aios.access(hf_home, os.W_OK)
-  except OSError: return False
-

+ 0 - 172
exo/download/hf/hf_shard_download.py

@@ -1,172 +0,0 @@
-import asyncio
-import traceback
-from pathlib import Path
-from typing import Dict, List, Tuple, Optional, Union
-from exo.inference.shard import Shard
-from exo.download.shard_download import ShardDownloader
-from exo.download.download_progress import RepoProgressEvent
-from exo.download.hf.hf_helpers import (
-    download_repo_files, RepoProgressEvent, get_weight_map, 
-    get_allow_patterns, get_repo_root, fetch_file_list, 
-    get_local_snapshot_dir, get_file_download_percentage,
-    filter_repo_objects
-)
-from exo.helpers import AsyncCallbackSystem, DEBUG
-from exo.models import model_cards, get_repo
-import aiohttp
-from aiofiles import os as aios
-
-
-class HFShardDownloader(ShardDownloader):
-  def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
-    self.quick_check = quick_check
-    self.max_parallel_downloads = max_parallel_downloads
-    self.active_downloads: Dict[Shard, asyncio.Task] = {}
-    self.completed_downloads: Dict[Shard, Path] = {}
-    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
-    self.current_shard: Optional[Shard] = None
-    self.current_repo_id: Optional[str] = None
-    self.revision: str = "main"
-
-  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
-    self.current_shard = shard
-    self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
-    repo_name = get_repo(shard.model_id, inference_engine_name)
-    if shard in self.completed_downloads:
-      return self.completed_downloads[shard]
-    if self.quick_check:
-      repo_root = get_repo_root(repo_name)
-      snapshots_dir = repo_root/"snapshots"
-      if snapshots_dir.exists():
-        visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
-        if visible_dirs:
-          most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime)
-          return most_recent_dir
-
-    # If a download on this shard is already in progress, keep that one
-    for active_shard in self.active_downloads:
-      if active_shard == shard:
-        if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
-        return await self.active_downloads[shard]
-
-    # Cancel any downloads for this model_id on a different shard
-    existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
-    for active_shard in existing_active_shards:
-      if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
-      task = self.active_downloads[active_shard]
-      task.cancel()
-      try:
-        await task
-      except asyncio.CancelledError:
-        pass  # This is expected when cancelling a task
-      except Exception as e:
-        if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
-        traceback.print_exc()
-    self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
-
-    # Start new download
-    download_task = asyncio.create_task(self._download_shard(shard, repo_name))
-    self.active_downloads[shard] = download_task
-    try:
-      path = await download_task
-      self.completed_downloads[shard] = path
-      return path
-    finally:
-      # Ensure the task is removed even if an exception occurs
-      print(f"Removing download task for {shard}: {shard in self.active_downloads}")
-      if shard in self.active_downloads:
-        self.active_downloads.pop(shard)
-
-  async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
-    async def wrapped_progress_callback(event: RepoProgressEvent):
-      self._on_progress.trigger_all(shard, event)
-
-    weight_map = await get_weight_map(repo_name)
-    allow_patterns = get_allow_patterns(weight_map, shard)
-
-    return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
-
-  @property
-  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
-    return self._on_progress
-
-  async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
-    if not self.current_shard or not self.current_repo_id:
-      if DEBUG >= 2:
-        print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
-      return None
-
-    try:
-      # If no snapshot directory exists, return None - no need to check remote files
-      snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
-      if not snapshot_dir:
-        if DEBUG >= 2:
-          print(f"No snapshot directory found for {self.current_repo_id}")
-        return None
-
-      if not await aios.path.exists(snapshot_dir/"model_index.json"):
-      # Get the weight map to know what files we need
-        weight_map = await get_weight_map(self.current_repo_id, self.revision)
-        if not weight_map:
-          if DEBUG >= 2:
-            print(f"No weight map found for {self.current_repo_id}")
-          return None
-
-        # Get all files needed for this shard
-        patterns = get_allow_patterns(weight_map, self.current_shard)
-      else:
-        patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
-
-
-      # Check download status for all relevant files
-      status = {}
-      total_bytes = 0
-      downloaded_bytes = 0
-
-      async with aiohttp.ClientSession() as session:
-        file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
-        relevant_files = list(
-            filter_repo_objects(
-                file_list, allow_patterns=patterns, key=lambda x: x["path"]))
-
-        for file in relevant_files:
-          file_size = file["size"]
-          total_bytes += file_size
-
-          percentage = await get_file_download_percentage(
-              session,
-              self.current_repo_id,
-              self.revision,
-              file["path"],
-              snapshot_dir,
-          )
-          status[file["path"]] = percentage
-          downloaded_bytes += (file_size * (percentage / 100))
-
-        # Add overall progress weighted by file size
-        if total_bytes > 0:
-          status["overall"] = (downloaded_bytes / total_bytes) * 100
-        else:
-          status["overall"] = 0
-          
-        # Add total size in bytes
-        status["total_size"] = total_bytes
-        if status["overall"] != 100:
-          status["total_downloaded"] = downloaded_bytes
-        
-
-        if DEBUG >= 2:
-          print(f"Download calculation for {self.current_repo_id}:")
-          print(f"Total bytes: {total_bytes}")
-          print(f"Downloaded bytes: {downloaded_bytes}")
-        if DEBUG >= 3:
-          for file in relevant_files:
-            print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
-
-      return status
-
-    except Exception as e:
-      if DEBUG >= 3:
-        print(f"Error getting shard download status: {e}")
-        traceback.print_exc()
-      return None

+ 34 - 3
exo/download/hf/new_shard_download.py → exo/download/new_shard_download.py

@@ -16,15 +16,46 @@ import time
 from datetime import timedelta
 import asyncio
 import json
+import traceback
+import shutil
 
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
 
 
+async def has_exo_home_read_access() -> bool:
+  try: return await aios.access(exo_home(), os.R_OK)
+  except OSError: return False
+
+async def has_exo_home_write_access() -> bool:
+  try: return await aios.access(exo_home(), os.W_OK)
+  except OSError: return False
+
 async def ensure_downloads_dir() -> Path:
   downloads_dir = exo_home()/"downloads"
   await aios.makedirs(downloads_dir, exist_ok=True)
   return downloads_dir
 
 
+async def delete_model(model_id: str, inference_engine_name: str) -> bool:
+  repo_id = get_repo(model_id, inference_engine_name)
+  model_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
+  if not await aios.path.exists(model_dir): return False
+  await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
+  return True
+
+async def seed_models(seed_dir: Union[str, Path]):
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(seed_dir)
+  dest_dir = await ensure_downloads_dir()
+  for path in source_dir.iterdir():
+    if path.is_dir() and path.name.startswith("models--"):
+      dest_path = dest_dir/path.name
+      if await aios.path.exists(dest_path): print('Skipping moving model to .cache directory')
+      else:
+        try: await aios.rename(str(path), str(dest_path))
+        except:
+          print(f"Error seeding model {path} to {dest_path}")
+          traceback.print_exc()
+
 async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
@@ -44,8 +75,9 @@ async def fetch_file_list(session, repo_id, revision, path=""):
     else:
       raise Exception(f"Failed to fetch file list: {response.status}")
 
 
-async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Optional[Callable[[int, int], None]] = None) -> Path:
+async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
   if (target_dir/path).exists(): return target_dir/path
+  await aios.makedirs((target_dir/path).parent, exist_ok=True)
   base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
   base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
   url = urljoin(base_url, path)
   url = urljoin(base_url, path)
   headers = await get_auth_headers()
   headers = await get_auth_headers()
@@ -71,8 +103,7 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]
   target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
   async with aiohttp.ClientSession() as session:
     index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
-    async with aiofiles.open(index_file, 'r') as f:
-      index_data = json.loads(await f.read())
+    async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
     return index_data.get("weight_map")

 async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> list[str]:

+ 0 - 0
exo/download/hf/test_new_shard_download.py → exo/download/test_new_shard_download.py


+ 16 - 17
exo/inference/tokenizers.py

@@ -1,12 +1,11 @@
 import traceback
-from aiofiles import os as aios
 from os import PathLike
-from pathlib import Path
+from aiofiles import os as aios
 from typing import Union
 from transformers import AutoTokenizer, AutoProcessor
 import numpy as np
-from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
+from exo.download.new_shard_download import ensure_downloads_dir
 
 
 
 
 class DummyTokenizer:
@@ -24,25 +23,25 @@ class DummyTokenizer:
     return "dummy" * len(tokens)
     return "dummy" * len(tokens)
 
 
 
 
-async def resolve_tokenizer(model_id: str):
-  if model_id == "dummy":
+async def resolve_tokenizer(repo_id: Union[str, PathLike]):
+  if repo_id == "dummy":
     return DummyTokenizer()
-  local_path = await get_local_snapshot_dir(model_id)
+  local_path = await ensure_downloads_dir()/str(repo_id).replace("/", "--")
   if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
   if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
   try:
   try:
     if local_path and await aios.path.exists(local_path):
     if local_path and await aios.path.exists(local_path):
-      if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
+      if DEBUG >= 2: print(f"Resolving tokenizer for {repo_id=} from {local_path=}")
       return await _resolve_tokenizer(local_path)
   except:
-    if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
+    if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {repo_id=} normally...")
     if DEBUG >= 5: traceback.print_exc()
-  return await _resolve_tokenizer(model_id)
+  return await _resolve_tokenizer(repo_id)
 
 
 
 
-async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
+async def _resolve_tokenizer(repo_id_or_local_path: Union[str, PathLike]):
   try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
-    processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {repo_id_or_local_path}")
+    processor = AutoProcessor.from_pretrained(repo_id_or_local_path, use_fast=True if "Mistral-Large" in f"{repo_id_or_local_path}" else False, trust_remote_code=True)
     if not hasattr(processor, 'eos_token_id'):
       processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
     if not hasattr(processor, 'encode'):
@@ -51,14 +50,14 @@ async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
       processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load processor for {repo_id_or_local_path}. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())
 
 
   try:
-    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
-    return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {repo_id_or_local_path}")
+    return AutoTokenizer.from_pretrained(repo_id_or_local_path, trust_remote_code=True)
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {repo_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())
 
 
-  raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
+  raise ValueError(f"[TODO] Unsupported model: {repo_id_or_local_path}")

+ 9 - 10
exo/main.py

@@ -23,19 +23,18 @@ from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
 from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
-from exo.download.hf.new_shard_download import new_shard_downloader
+from exo.download.shard_download import ShardDownloader, NoopShardDownloader
+from exo.download.download_progress import RepoProgressEvent
+from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, exo_home, seed_models
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
-from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
 import uvloop
 from contextlib import asynccontextmanager
 import concurrent.futures
-import socket
 import resource
 import psutil
 
 
@@ -321,13 +320,13 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
 async def main():
   loop = asyncio.get_running_loop()
 
 
-  # Check HuggingFace directory permissions
-  hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
-  if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
+  # Check exo directory permissions
+  home, has_read, has_write = exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access()
+  if DEBUG >= 1: print(f"exo home directory: {home}")
   print(f"{has_read=}, {has_write=}")
   print(f"{has_read=}, {has_write=}")
   if not has_read or not has_write:
   if not has_read or not has_write:
     print(f"""
     print(f"""
-          WARNING: Limited permissions for model storage directory: {hf_home}.
+          WARNING: Limited permissions for exo home directory: {home}.
           This may prevent model downloads from working correctly.
           {"❌ No read access" if not has_read else ""}
           {"❌ No write access" if not has_write else ""}
@@ -336,9 +335,9 @@ async def main():
   if not args.models_seed_dir is None:
     try:
       models_seed_dir = clean_path(args.models_seed_dir)
-      await move_models_to_hf(models_seed_dir)
+      await seed_models(models_seed_dir)
     except Exception as e:
-      print(f"Error moving models to .cache/huggingface: {e}")
+      print(f"Error seeding models: {e}")
 
 
   def restore_cursor():
     if platform.system() != "Windows":

+ 1 - 1
exo/orchestration/node.py

@@ -13,7 +13,7 @@ from exo.topology.partitioning_strategy import Partition, PartitioningStrategy,
 from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
-from exo.download.hf.hf_helpers import RepoProgressEvent
+from exo.download.download_progress import RepoProgressEvent
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.download.shard_download import ShardDownloader
 
 

+ 1 - 1
exo/viz/test_topology_viz.py

@@ -5,7 +5,7 @@ from exo.viz.topology_viz import TopologyViz
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.partitioning_strategy import Partition
-from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
+from exo.download.download_progress import RepoProgressEvent
 
 
 
 
 def create_hf_repo_progress_event(

+ 1 - 1
exo/viz/topology_viz.py

@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple, Dict
 from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
-from exo.download.hf.hf_helpers import RepoProgressEvent
+from exo.download.download_progress import RepoProgressEvent
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
 from rich.console import Console, Group
 from rich.text import Text