
fix visual bug where the frontend would show the full HF repo size; in some cases the repo includes redundant files, so use the model index to compute the size in those cases too

Alex Cheema · 5 months ago
commit b349e48b0d
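
The gist of the fix: summing every file in the Hugging Face repo over-counts whenever the repo carries files the model never loads (duplicate weight formats, extra checkpoints). The model's safetensors index lists exactly which weight files are needed, so the size should be derived from it. A minimal sketch of the idea, assuming a hypothetical fetch_file_list helper; only get_weight_map mirrors a real helper in this diff:

  async def true_download_size(repo_id: str, revision: str = "main") -> int:
    weight_map = await get_weight_map(repo_id, revision)  # tensor name -> weight file
    needed = set(weight_map.values())
    files = await fetch_file_list(repo_id, revision)  # hypothetical: [{"path": ..., "size": ...}]
    # Keep non-weight files (config, tokenizer) plus only the weight files
    # the index actually references, instead of the whole repo.
    return sum(f["size"] for f in files
               if not f["path"].endswith(".safetensors") or f["path"] in needed)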

+ 4 - 9
exo/api/chatgpt_api.py

@@ -14,7 +14,7 @@ from exo import DEBUG, VERSION
 from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo, get_model_id, get_supported_models, get_pretty_name
+from exo.models import build_base_shard, build_full_shard, model_cards, get_repo, get_supported_models, get_pretty_name
 from typing import Callable, Optional
 from PIL import Image
 import numpy as np
@@ -22,7 +22,7 @@ import base64
 from io import BytesIO
 import platform
 from exo.download.download_progress import RepoProgressEvent
-from exo.download.new_shard_download import ensure_downloads_dir, delete_model
+from exo.download.new_shard_download import delete_model
 import tempfile
 from exo.apputil import create_animation_mp4
 from collections import defaultdict
@@ -278,7 +278,7 @@ class ChatGPTAPI:
       await response.prepare(request)
       downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
       for (path, d) in downloads:
-        model_data = { get_model_id(d.repo_id): { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
+        model_data = { d.shard.model_id: { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
         await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
       await response.write(b"data: [DONE]\n\n")
       return response
@@ -551,11 +551,6 @@ class ChatGPTAPI:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500)
 
-    except Exception as e:
-      print(f"Error in handle_delete_model: {str(e)}")
-      traceback.print_exc()
-      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
-
   async def handle_get_initial_models(self, request):
     model_data = {}
     for model_id in get_supported_models([[self.inference_engine_classname]]):
@@ -606,7 +601,7 @@ class ChatGPTAPI:
       model_name = data.get("model")
       if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
       if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
-      shard = build_base_shard(model_name, self.inference_engine_classname)
+      shard = build_full_shard(model_name, self.inference_engine_classname)
       if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
       asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
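
With the shard carried on the progress event, each server-sent event is keyed directly by the shard's model id instead of reverse-mapping the repo id through get_model_id. One event line might look like this (illustrative values, hypothetical model id):

  data: {"llama-3.2-1b": {"downloaded": false, "download_percentage": 42.5, "total_size": 1250000000, "total_downloaded": 531250000}}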
 

+ 4 - 1
exo/download/download_progress.py

@@ -1,4 +1,5 @@
 from typing import Dict, Callable, Coroutine, Any, Literal
+from exo.inference.shard import Shard
 from dataclasses import dataclass
 from datetime import timedelta
 
@@ -30,6 +31,7 @@ class RepoFileProgressEvent:
 
 @dataclass
 class RepoProgressEvent:
+  shard: Shard
   repo_id: str
   repo_revision: str
   completed_files: int
@@ -44,7 +46,7 @@ class RepoProgressEvent:
 
   def to_dict(self):
     return {
-      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
+      "shard": self.shard.to_dict(), "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
       "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
       "file_progress": {k: v.to_dict()
                         for k, v in self.file_progress.items()}, "status": self.status
@@ -54,6 +56,7 @@ class RepoProgressEvent:
   def from_dict(cls, data):
     if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta'])
     if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
+    if 'shard' in data: data['shard'] = Shard.from_dict(data['shard'])
 
     return cls(**data)
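
Since from_dict now rehydrates the nested shard, a progress event survives a serialization round trip. A minimal sketch, assuming Shard(model_id, start_layer, end_layer, n_layers) and the field order shown in calculate_repo_progress below (model id and sizes are made up):

  from datetime import timedelta
  from exo.inference.shard import Shard
  from exo.download.download_progress import RepoProgressEvent

  event = RepoProgressEvent(
    Shard("llama-3.2-1b", 0, 15, 16), "mlx-community/Llama-3.2-1B-4bit", "main",
    completed_files=0, total_files=4, downloaded_bytes=0,
    downloaded_bytes_this_session=0, total_bytes=1_250_000_000,
    overall_speed=0, overall_eta=timedelta(seconds=0),
    file_progress={}, status="not_started",
  )
  # assumes Shard is a value-comparable dataclass
  assert RepoProgressEvent.from_dict(event.to_dict()).shard == event.shard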
 

+ 22 - 13
exo/download/new_shard_download.py

@@ -5,7 +5,7 @@ from exo.download.hf.hf_helpers import get_hf_endpoint, get_auth_headers, filter
 from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent
 from exo.helpers import AsyncCallbackSystem, DEBUG
-from exo.models import get_supported_models, build_base_shard
+from exo.models import get_supported_models, build_full_shard
 import os
 import aiofiles.os as aios
 import aiohttp
@@ -18,10 +18,18 @@ import asyncio
 import json
 import traceback
 import shutil
+import tempfile
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
 
+def exo_tmp() -> Path:
+  return Path(tempfile.gettempdir())/"exo"
+
+async def ensure_exo_tmp() -> Path:
+  await aios.makedirs(exo_tmp(), exist_ok=True)
+  return exo_tmp()
+
 async def has_exo_home_read_access() -> bool:
   try: return await aios.access(exo_home(), os.R_OK)
   except OSError: return False
@@ -90,17 +98,17 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
       await aios.rename(temp_file.name, target_dir/path)
     return target_dir/path
 
-def calculate_repo_progress(repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
+def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
   all_total_bytes = sum([p.total for p in file_progress.values()])
   all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()])
   elapsed_time = time.time() - all_start_time
   all_speed = all_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
   all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
   status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress"
-  return RepoProgressEvent(repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes, all_total_bytes, all_speed, all_eta, file_progress, status)
+  return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes, all_total_bytes, all_speed, all_eta, file_progress, status)
 
 async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
-  target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
+  target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--")
   async with aiohttp.ClientSession() as session:
     index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
     async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
@@ -110,12 +118,13 @@ async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str)
   try:
     weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
     return get_allow_patterns(weight_map, shard)
-  except Exception as e:
-    if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}: {e}")
+  except Exception:
+    if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}")
+    if DEBUG >= 1: traceback.print_exc()
     return ["*"]
 
 async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
-  if DEBUG >= 6 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
+  if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
   repo_id = get_repo(shard.model_id, inference_engine_classname)
   revision = "main"
   target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
@@ -124,8 +133,8 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   if repo_id is None:
     raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}")
 
-  allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname) if not skip_download else None
-  if DEBUG >= 3: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
+  allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname)
+  if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
 
   all_start_time = time.time()
   async with aiohttp.ClientSession() as session:
@@ -137,8 +146,8 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
       speed = curr_bytes / (time.time() - start_time)
       eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
       file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, curr_bytes, total_bytes, speed, eta, "in_progress", start_time)
-      on_progress.trigger_all(shard, calculate_repo_progress(repo_id, revision, file_progress, all_start_time))
-      if DEBUG >= 2: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
+      on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
+      if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
     for file in filtered_file_list:
       downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
       file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time())
@@ -148,7 +157,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
       async with semaphore:
         await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
     if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
-    final_repo_progress = calculate_repo_progress(repo_id, revision, file_progress, all_start_time)
+    final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
     on_progress.trigger_all(shard, final_repo_progress)
     return target_dir, final_repo_progress
 
@@ -208,7 +217,7 @@ class NewShardDownloader(ShardDownloader):
 
   async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
     if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
-    downloads = await asyncio.gather(*[download_shard(build_base_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
+    downloads = await asyncio.gather(*[download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
     if DEBUG >= 6: print("Downloaded shards:", downloads)
     if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
     return [d for d in downloads if not isinstance(d, Exception)]
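
After this change, get_shard_download_status reports totals computed from the allow-pattern-filtered file list of a full shard, so they match what would actually be downloaded. Usage, mirroring the test below:

  import asyncio
  from exo.download.new_shard_download import NewShardDownloader

  async def main():
    downloader = NewShardDownloader()
    for path, progress in await downloader.get_shard_download_status("MLXDynamicShardInferenceEngine"):
      # total_bytes now counts only the index-selected files, not the whole repo
      print(progress.shard.model_id, f"{progress.downloaded_bytes}/{progress.total_bytes}")

  asyncio.run(main())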

+ 1 - 6
exo/download/test_new_shard_download.py

@@ -1,6 +1,5 @@
 from exo.download.new_shard_download import download_shard, NewShardDownloader
 from exo.inference.shard import Shard
-from exo.models import get_model_id
 from pathlib import Path
 import asyncio
 
@@ -11,10 +10,6 @@ async def test_new_shard_download():
   download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
   print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})
 
-async def test_helpers():
-  print(get_model_id("mlx-community/Llama-3.3-70B-Instruct-4bit"))
-  print(get_model_id("fjsekljfd"))
-
 if __name__ == "__main__":
-  asyncio.run(test_helpers())
+  asyncio.run(test_new_shard_download())
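
The test can be run directly, e.g. python -m exo.download.test_new_shard_download (assuming the package is importable from the repo root).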
 

+ 5 - 3
exo/models.py

@@ -235,9 +235,6 @@ pretty_name = {
 def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
   return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
 
-def get_model_id(repo_id: str) -> Optional[str]:
-  return next((model_id for model_id, card in model_cards.items() if repo_id in card["repo"].values()), None)
-
 def get_pretty_name(model_id: str) -> Optional[str]:
   return pretty_name.get(model_id, None)
 
@@ -248,6 +245,11 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
     return None
   return Shard(model_id, 0, 0, n_layers)
 
+def build_full_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
+  base_shard = build_base_shard(model_id, inference_engine_classname)
+  if base_shard is None: return None
+  return Shard(base_shard.model_id, 0, base_shard.n_layers - 1, base_shard.n_layers)
+
 def get_supported_models(supported_inference_engine_lists: Optional[List[List[str]]] = None) -> List[str]:
   if not supported_inference_engine_lists:
     return list(model_cards.keys())