
rewrite ShardDownloader, simplify significantly

Alex Cheema, 3 months ago
parent
commit
b89495f444

+ 11 - 40
exo/api/chatgpt_api.py

@@ -11,17 +11,17 @@ import aiohttp_cors
 import traceback
 import signal
 from exo import DEBUG, VERSION
-from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo, pretty_name
+from exo.models import build_base_shard, model_cards, get_repo, get_model_id, get_supported_models, get_pretty_name
 from typing import Callable, Optional
 from PIL import Image
 import numpy as np
 import base64
 from io import BytesIO
 import platform
+from exo.download.shard_download import RepoProgressEvent
 
 if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
   import mlx.core as mx
@@ -29,7 +29,6 @@ else:
   import numpy as mx
 
 import tempfile
-from exo.download.hf.hf_shard_download import HFShardDownloader
 import shutil
 from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
 from exo.apputil import create_animation_mp4
@@ -277,41 +276,12 @@ class ChatGPTAPI:
 
   async def handle_model_support(self, request):
     try:
-      response = web.StreamResponse(status=200, reason='OK', headers={
-        'Content-Type': 'text/event-stream',
-        'Cache-Control': 'no-cache',
-        'Connection': 'keep-alive',
-      })
+      response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
       await response.prepare(request)
-
-      async def process_model(model_name, pretty):
-        if model_name in model_cards:
-          model_info = model_cards[model_name]
-
-          if self.inference_engine_classname in model_info.get("repo", {}):
-            shard = build_base_shard(model_name, self.inference_engine_classname)
-            if shard:
-              downloader = HFShardDownloader(quick_check=True)
-              downloader.current_shard = shard
-              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-              status = await downloader.get_shard_download_status()
-
-              download_percentage = status.get("overall") if status else None
-              total_size = status.get("total_size") if status else None
-              total_downloaded = status.get("total_downloaded") if status else False
-
-              model_data = {
-                model_name: {
-                  "name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
-                  "total_downloaded": total_downloaded
-                }
-              }
-
-              await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
-
-      # Process all models in parallel
-      await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])
-
+      downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
+      for (path, d) in downloads:
+        model_data = { get_model_id(d.repo_id): { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
+        await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
       await response.write(b"data: [DONE]\n\n")
       return response
 
@@ -348,6 +318,7 @@ class ChatGPTAPI:
     progress_data = {}
     for node_id, progress_event in self.node.node_download_progress.items():
       if isinstance(progress_event, RepoProgressEvent):
+        if progress_event.status != "in_progress": continue
         progress_data[node_id] = progress_event.to_dict()
       else:
         print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
@@ -611,9 +582,9 @@ class ChatGPTAPI:
 
   async def handle_get_initial_models(self, request):
     model_data = {}
-    for model_name, pretty in pretty_name.items():
-      model_data[model_name] = {
-        "name": pretty,
+    for model_id in get_supported_models([[self.inference_engine_classname]]):
+      model_data[model_id] = {
+        "name": get_pretty_name(model_id),
         "downloaded": None,  # Initially unknown
         "download_percentage": None,  # Change from 0 to null
         "total_size": None,

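Reviewer note: the rewritten `handle_model_support` no longer spins up a throwaway `HFShardDownloader` per model; it streams one SSE event per `(path, RepoProgressEvent)` tuple returned by the node's shard downloader. A minimal sketch of the per-model payload mapping (`model_status_event` is a hypothetical helper; the zero-size guard is an addition in this sketch, not in the committed one-liner):

```python
import json

def model_status_event(model_id: str, downloaded_bytes: int, total_bytes: int) -> bytes:
  # Fully downloaded iff byte counts match, as in handle_model_support.
  pct = 100.0 if downloaded_bytes == total_bytes else (
    100.0 * downloaded_bytes / total_bytes if total_bytes else 0.0)  # guard added for the sketch
  payload = {model_id: {
    "downloaded": downloaded_bytes == total_bytes,
    "download_percentage": pct,
    "total_size": total_bytes,
    "total_downloaded": downloaded_bytes,
  }}
  return f"data: {json.dumps(payload)}\n\n".encode()  # SSE framing used by the handler

print(model_status_event("llama-3.2-1b", 512, 1024))
```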
+ 2 - 1
exo/download/download_progress.py

@@ -14,11 +14,12 @@ class RepoFileProgressEvent:
   speed: int
   eta: timedelta
   status: Literal["not_started", "in_progress", "complete"]
+  start_time: float
 
   def to_dict(self):
     return {
       "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
-      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
+      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status, "start_time": self.start_time
     }
 
   @classmethod

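The new `start_time` field is what lets per-file speed and ETA be recomputed from the event alone. A sketch of that arithmetic, mirroring `on_progress_wrapper` in `new_shard_download.py` below (the `elapsed > 0` guard is an addition in this sketch):

```python
import time
from datetime import timedelta

def speed_and_eta(curr_bytes: int, total_bytes: int, start_time: float) -> tuple[float, timedelta]:
  elapsed = time.time() - start_time
  speed = curr_bytes / elapsed if elapsed > 0 else 0.0  # bytes per second since this file began
  eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) if speed > 0 else timedelta(0)
  return speed, eta
```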
+ 0 - 1
exo/download/hf/hf_helpers.py

@@ -209,7 +209,6 @@ async def download_file(
       raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
 
     if downloaded_size == total_size:
-      print(f"File already downloaded: {file_path}")
       if progress_callback:
         await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
       return

+ 180 - 0
exo/download/hf/new_shard_download.py

@@ -0,0 +1,180 @@
+from exo.inference.shard import Shard
+from exo.models import get_repo
+from pathlib import Path
+from exo.download.hf.hf_helpers import get_hf_endpoint, get_auth_headers, filter_repo_objects, get_allow_patterns
+from exo.download.shard_download import ShardDownloader
+from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent
+from exo.helpers import AsyncCallbackSystem, DEBUG
+from exo.models import get_supported_models, build_base_shard
+import os
+import aiofiles.os as aios
+import aiohttp
+import aiofiles
+from urllib.parse import urljoin
+from typing import Optional, Callable, Union, Tuple, Dict
+import time
+from datetime import timedelta
+import asyncio
+import json
+
+def exo_home() -> Path:
+  return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
+
+async def ensure_downloads_dir() -> Path:
+  downloads_dir = exo_home()/"downloads"
+  await aios.makedirs(downloads_dir, exist_ok=True)
+  return downloads_dir
+
+async def fetch_file_list(session, repo_id, revision, path=""):
+  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
+  url = f"{api_url}/{path}" if path else api_url
+
+  headers = await get_auth_headers()
+  async with session.get(url, headers=headers) as response:
+    if response.status == 200:
+      data = await response.json()
+      files = []
+      for item in data:
+        if item["type"] == "file":
+          files.append({"path": item["path"], "size": item["size"]})
+        elif item["type"] == "directory":
+          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
+          files.extend(subfiles)
+      return files
+    else:
+      raise Exception(f"Failed to fetch file list: {response.status}")
+
+async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Optional[Callable[[int, int], None]] = None) -> Path:
+  if (target_dir/path).exists(): return target_dir/path
+  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
+  url = urljoin(base_url, path)
+  headers = await get_auth_headers()
+  async with session.get(url, headers=headers) as r:
+    assert r.status == 200, r.status
+    length = int(r.headers.get('content-length', 0))
+    n_read = 0
+    async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
+      while chunk := await r.content.read(16384): on_progress(n_read := n_read + await temp_file.write(chunk), length)
+      await aios.rename(temp_file.name, target_dir/path)
+    return target_dir/path
+
+def calculate_repo_progress(repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
+  all_total_bytes = sum([p.total for p in file_progress.values()])
+  all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()])
+  elapsed_time = time.time() - all_start_time
+  all_speed = all_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
+  all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
+  status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress"
+  return RepoProgressEvent(repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes, all_total_bytes, all_speed, all_eta, file_progress, status)
+
+async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
+  target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
+  async with aiohttp.ClientSession() as session:
+    index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
+    async with aiofiles.open(index_file, 'r') as f:
+      index_data = json.loads(await f.read())
+    return index_data.get("weight_map")
+
+async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> list[str]:
+  weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
+  return get_allow_patterns(weight_map, shard)
+
+async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
+  if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
+  repo_id = get_repo(shard.model_id, inference_engine_classname)
+  revision = "main"
+  target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
+  if not skip_download: await aios.makedirs(target_dir, exist_ok=True)
+
+  if repo_id is None:
+    raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}")
+
+  allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname) if not skip_download else None
+  if DEBUG >= 3: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
+
+  all_start_time = time.time()
+  async with aiohttp.ClientSession() as session:
+    file_list = await fetch_file_list(session, repo_id, revision)
+    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
+    file_progress: Dict[str, RepoFileProgressEvent] = {}
+    def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
+      start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
+      speed = curr_bytes / (time.time() - start_time)
+      eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
+      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, curr_bytes, total_bytes, speed, eta, "in_progress", start_time)
+      on_progress.trigger_all(shard, calculate_repo_progress(repo_id, revision, file_progress, all_start_time))
+      if DEBUG >= 2: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
+    for file in filtered_file_list:
+      downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
+      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time())
+
+    semaphore = asyncio.Semaphore(max_parallel_downloads)
+    async def download_with_semaphore(file):
+      async with semaphore:
+        await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
+    if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
+    final_repo_progress = calculate_repo_progress(repo_id, revision, file_progress, all_start_time)
+    on_progress.trigger_all(shard, final_repo_progress)
+    return target_dir, final_repo_progress
+
+def new_shard_downloader() -> ShardDownloader:
+  return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))
+
+class SingletonShardDownloader(ShardDownloader):
+  def __init__(self, shard_downloader: ShardDownloader):
+    self.shard_downloader = shard_downloader
+    self.active_downloads: Dict[Shard, asyncio.Task] = {}
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self.shard_downloader.on_progress
+
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    if shard not in self.active_downloads: self.active_downloads[shard] = asyncio.create_task(self.shard_downloader.ensure_shard(shard, inference_engine_name))
+    try: return await self.active_downloads[shard]
+    finally:
+      if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]
+
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+
+class CachedShardDownloader(ShardDownloader):
+  def __init__(self, shard_downloader: ShardDownloader):
+    self.shard_downloader = shard_downloader
+    self.cache: Dict[tuple[str, Shard], Path] = {}
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self.shard_downloader.on_progress
+
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    if (inference_engine_name, shard) in self.cache:
+      if DEBUG >= 2: print(f"ensure_shard cache hit {shard=} for {inference_engine_name}")
+      return self.cache[(inference_engine_name, shard)]
+    if DEBUG >= 2: print(f"ensure_shard cache miss {shard=} for {inference_engine_name}")
+    target_dir = await self.shard_downloader.ensure_shard(shard, inference_engine_name)
+    self.cache[(inference_engine_name, shard)] = target_dir
+    return target_dir
+
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+
+class NewShardDownloader(ShardDownloader):
+  def __init__(self):
+    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self._on_progress
+
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
+    return target_dir
+
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+    if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
+    downloads = await asyncio.gather(*[download_shard(build_base_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
+    if DEBUG >= 6: print("Downloaded shards:", downloads)
+    if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
+    return [d for d in downloads if not isinstance(d, Exception)]
+

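The three classes compose as decorators: `NewShardDownloader` does the HTTP work, `CachedShardDownloader` memoizes completed `(engine, shard)` pairs, and `SingletonShardDownloader` coalesces concurrent requests for the same shard into a single task. A minimal usage sketch under those assumptions (model and engine names borrowed from the test file below):

```python
import asyncio
from exo.download.hf.new_shard_download import new_shard_downloader
from exo.inference.shard import Shard

async def main():
  downloader = new_shard_downloader()  # Singleton(Cached(NewShardDownloader()))
  downloader.on_progress.register("logger").on_next(
    lambda shard, event: print(f"{shard.model_id}: {event.downloaded_bytes}/{event.total_bytes} bytes"))
  shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16)
  # Concurrent calls for the same shard share one download task;
  # once complete, repeat calls are served from the in-memory cache.
  path, _ = await asyncio.gather(
    downloader.ensure_shard(shard, "MLXDynamicShardInferenceEngine"),
    downloader.ensure_shard(shard, "MLXDynamicShardInferenceEngine"))
  print(path)

if __name__ == "__main__":
  asyncio.run(main())
```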
+ 20 - 0
exo/download/hf/test_new_shard_download.py

@@ -0,0 +1,20 @@
+from exo.download.hf.new_shard_download import download_shard, NewShardDownloader
+from exo.inference.shard import Shard
+from exo.models import get_model_id
+from pathlib import Path
+import asyncio
+
+async def test_new_shard_download():
+  shard_downloader = NewShardDownloader()
+  shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
+  await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
+  download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
+  print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})
+
+async def test_helpers():
+  print(get_model_id("mlx-community/Llama-3.3-70B-Instruct-4bit"))
+  print(get_model_id("fjsekljfd"))
+
+if __name__ == "__main__":
+  asyncio.run(test_helpers())
+

+ 2 - 2
exo/download/shard_download.py

@@ -27,7 +27,7 @@ class ShardDownloader(ABC):
     pass
 
   @abstractmethod
-  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
     """Get the download status of shards.
     
     Returns:
@@ -45,5 +45,5 @@ class NoopShardDownloader(ShardDownloader):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return AsyncCallbackSystem()
 
-  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> Optional[Dict[str, float]]:
     return None

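With the signature change, callers receive a list of `(Path, RepoProgressEvent)` tuples instead of a flat percentage dict. A hedged sketch of a caller under the new contract (`print_download_table` is hypothetical; note `NoopShardDownloader` still returns `None`, hence the `or []`):

```python
async def print_download_table(downloader, engine_name: str) -> None:
  # Each entry pairs a local target directory with an aggregate progress event.
  statuses = await downloader.get_shard_download_status(engine_name) or []
  for path, event in statuses:
    pct = 100.0 * event.downloaded_bytes / event.total_bytes if event.total_bytes else 0.0
    print(f"{path}: {pct:.1f}% ({event.downloaded_bytes}/{event.total_bytes} bytes)")
```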
+ 3 - 4
exo/main.py

@@ -24,7 +24,7 @@ from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
-from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.download.hf.new_shard_download import new_shard_downloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
@@ -117,8 +117,7 @@ print_yellow_exo()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
-                                                      max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
+shard_downloader: ShardDownloader = new_shard_downloader() if args.inference_engine != "dummy" else NoopShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 print(f"Inference engine name after selection: {inference_engine_name}")
 
@@ -175,10 +174,10 @@ node = Node(
   None,
   inference_engine,
   discovery,
+  shard_downloader,
   partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
   max_generate_tokens=args.max_generate_tokens,
   topology_viz=topology_viz,
-  shard_downloader=shard_downloader,
   default_sample_temperature=args.default_temp
 )
 server = GRPCServer(node, args.node_host, args.node_port)

+ 10 - 1
exo/models.py

@@ -175,8 +175,11 @@ pretty_name = {
   "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
   "deepseek-coder-v2.5": "Deepseek Coder V2.5",
   "deepseek-v3": "Deepseek V3",
+  "deepseek-v3-3bit": "Deepseek V3 (3-bit)",
   "deepseek-r1": "Deepseek R1",
+  "deepseek-r1-3bit": "Deepseek R1 (3-bit)",
   "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
+  "qwen-2.5-0.5b": "Qwen 2.5 0.5B",
   "qwen-2.5-1.5b": "Qwen 2.5 1.5B",
   "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
   "qwen-2.5-3b": "Qwen 2.5 3B",
@@ -232,6 +235,12 @@ pretty_name = {
 def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
   return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
 
+def get_model_id(repo_id: str) -> Optional[str]:
+  return next((model_id for model_id, card in model_cards.items() if repo_id in card["repo"].values()), None)
+
+def get_pretty_name(model_id: str) -> Optional[str]:
+  return pretty_name.get(model_id, None)
+
 def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
   repo = get_repo(model_id, inference_engine_classname)
   n_layers = model_cards.get(model_id, {}).get("layers", 0)
@@ -239,7 +248,7 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
     return None
   return Shard(model_id, 0, 0, n_layers)
 
-def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
+def get_supported_models(supported_inference_engine_lists: Optional[List[List[str]]] = None) -> List[str]:
   if not supported_inference_engine_lists:
     return list(model_cards.keys())
 

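Quick illustration of the new helpers: `get_model_id` reverse-maps a repo id to exo's model id (returning `None` for unknown repos), `get_pretty_name` wraps the `pretty_name` table, and `get_supported_models` can now be called with no argument to list every model card. A short sketch (repo ids as exercised by `test_helpers` above):

```python
from exo.models import get_model_id, get_pretty_name, get_supported_models

model_id = get_model_id("mlx-community/Llama-3.3-70B-Instruct-4bit")  # exo model id, or None if unknown
print(model_id, get_pretty_name(model_id) if model_id else None)
print(get_model_id("fjsekljfd"))    # unknown repo -> None
print(len(get_supported_models()))  # no argument now means "all model cards"
```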
+ 3 - 3
exo/orchestration/node.py

@@ -15,7 +15,7 @@ from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
 from exo.download.hf.hf_helpers import RepoProgressEvent
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
-from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.download.shard_download import ShardDownloader
 
 class Node:
   def __init__(
@@ -24,16 +24,17 @@ class Node:
     server: Server,
     inference_engine: InferenceEngine,
     discovery: Discovery,
+    shard_downloader: ShardDownloader,
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
     default_sample_temperature: float = 0.0,
     topology_viz: Optional[TopologyViz] = None,
-    shard_downloader: Optional[HFShardDownloader] = None,
   ):
     self.id = _id
     self.inference_engine = inference_engine
     self.server = server
     self.discovery = discovery
+    self.shard_downloader = shard_downloader
     self.partitioning_strategy = partitioning_strategy
     self.peers: List[PeerHandle] = {}
     self.topology: Topology = Topology()
@@ -52,7 +53,6 @@ class Node:
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
     self.topology_inference_engines_pool: List[List[str]] = []
-    self.shard_downloader = shard_downloader
     self.outstanding_requests = {}
 
   async def start(self, wait_for_peers: int = 0) -> None:

+ 2 - 2
exo/orchestration/test_node.py

@@ -5,7 +5,7 @@ import pytest
 
 from .node import Node
 from exo.networking.peer_handle import PeerHandle
-
+from exo.download.shard_download import NoopShardDownloader
 
 class TestNode(unittest.IsolatedAsyncioTestCase):
   def setUp(self):
@@ -22,7 +22,7 @@ class TestNode(unittest.IsolatedAsyncioTestCase):
     mock_peer2.id.return_value = "peer2"
     self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
 
-    self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery)
+    self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery, NoopShardDownloader())
 
   async def asyncSetUp(self):
     await self.node.start()

+ 3 - 3
exo/tinychat/index.js

@@ -75,12 +75,12 @@ document.addEventListener("alpine:init", () => {
       while (true) {
         try {
           await this.populateSelector();
-          // Wait 5 seconds before next poll
-          await new Promise(resolve => setTimeout(resolve, 5000));
+          // Wait 15 seconds before next poll
+          await new Promise(resolve => setTimeout(resolve, 15000));
         } catch (error) {
           console.error('Model polling error:', error);
           // If there's an error, wait before retrying
-          await new Promise(resolve => setTimeout(resolve, 5000));
+          await new Promise(resolve => setTimeout(resolve, 15000));
         }
       }
     },