Sfoglia il codice sorgente

rewrite ShardDownloader, simplify significantly

Alex Cheema 3 mesi fa
parent
commit
b89495f444

+ 11 - 40
exo/api/chatgpt_api.py

@@ -11,17 +11,17 @@ import aiohttp_cors
 import traceback
 import signal
 from exo import DEBUG, VERSION
-from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo, pretty_name
+from exo.models import build_base_shard, model_cards, get_repo, get_model_id, get_supported_models, get_pretty_name
 from typing import Callable, Optional
 from PIL import Image
 import numpy as np
 import base64
 from io import BytesIO
 import platform
+from exo.download.shard_download import RepoProgressEvent
 
 if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
   import mlx.core as mx
@@ -29,7 +29,6 @@ else:
   import numpy as mx
 
 import tempfile
-from exo.download.hf.hf_shard_download import HFShardDownloader
 import shutil
 from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
 from exo.apputil import create_animation_mp4
@@ -277,41 +276,12 @@ class ChatGPTAPI:
 
   async def handle_model_support(self, request):
     try:
-      response = web.StreamResponse(status=200, reason='OK', headers={
-        'Content-Type': 'text/event-stream',
-        'Cache-Control': 'no-cache',
-        'Connection': 'keep-alive',
-      })
+      response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
       await response.prepare(request)
-
-      async def process_model(model_name, pretty):
-        if model_name in model_cards:
-          model_info = model_cards[model_name]
-
-          if self.inference_engine_classname in model_info.get("repo", {}):
-            shard = build_base_shard(model_name, self.inference_engine_classname)
-            if shard:
-              downloader = HFShardDownloader(quick_check=True)
-              downloader.current_shard = shard
-              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-              status = await downloader.get_shard_download_status()
-
-              download_percentage = status.get("overall") if status else None
-              total_size = status.get("total_size") if status else None
-              total_downloaded = status.get("total_downloaded") if status else False
-
-              model_data = {
-                model_name: {
-                  "name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
-                  "total_downloaded": total_downloaded
-                }
-              }
-
-              await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
-
-      # Process all models in parallel
-      await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])
-
+      downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
+      for (path, d) in downloads:
+        model_data = { get_model_id(d.repo_id): { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
+        await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
       await response.write(b"data: [DONE]\n\n")
       return response
 
@@ -348,6 +318,7 @@ class ChatGPTAPI:
     progress_data = {}
     for node_id, progress_event in self.node.node_download_progress.items():
       if isinstance(progress_event, RepoProgressEvent):
+        if progress_event.status != "in_progress": continue
         progress_data[node_id] = progress_event.to_dict()
       else:
         print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
@@ -611,9 +582,9 @@ class ChatGPTAPI:
 
   async def handle_get_initial_models(self, request):
     model_data = {}
-    for model_name, pretty in pretty_name.items():
-      model_data[model_name] = {
-        "name": pretty,
+    for model_id in get_supported_models([[self.inference_engine_classname]]):
+      model_data[model_id] = {
+        "name": get_pretty_name(model_id),
         "downloaded": None,  # Initially unknown
         "download_percentage": None,  # Change from 0 to null
         "total_size": None,

+ 2 - 1
exo/download/download_progress.py

@@ -14,11 +14,12 @@ class RepoFileProgressEvent:
   speed: int
   eta: timedelta
   status: Literal["not_started", "in_progress", "complete"]
+  start_time: float
 
   def to_dict(self):
     return {
       "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
-      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
+      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status, "start_time": self.start_time
     }
 
   @classmethod

+ 0 - 1
exo/download/hf/hf_helpers.py

@@ -209,7 +209,6 @@ async def download_file(
       raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
 
     if downloaded_size == total_size:
-      print(f"File already downloaded: {file_path}")
       if progress_callback:
         await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
       return

+ 180 - 0
exo/download/hf/new_shard_download.py

@@ -0,0 +1,180 @@
+from exo.inference.shard import Shard
+from exo.models import get_repo
+from pathlib import Path
+from exo.download.hf.hf_helpers import get_hf_endpoint, get_auth_headers, filter_repo_objects, get_allow_patterns
+from exo.download.shard_download import ShardDownloader
+from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent
+from exo.helpers import AsyncCallbackSystem, DEBUG
+from exo.models import get_supported_models, build_base_shard
+import os
+import aiofiles.os as aios
+import aiohttp
+import aiofiles
+from urllib.parse import urljoin
+from typing import Optional, Callable, Union, Tuple, Dict
+import time
+from datetime import timedelta
+import asyncio
+import json
+
+def exo_home() -> Path:
+  return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
+
+async def ensure_downloads_dir() -> Path:
+  downloads_dir = exo_home()/"downloads"
+  await aios.makedirs(downloads_dir, exist_ok=True)
+  return downloads_dir
+
+async def fetch_file_list(session, repo_id, revision, path=""):
+  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
+  url = f"{api_url}/{path}" if path else api_url
+
+  headers = await get_auth_headers()
+  async with session.get(url, headers=headers) as response:
+    if response.status == 200:
+      data = await response.json()
+      files = []
+      for item in data:
+        if item["type"] == "file":
+          files.append({"path": item["path"], "size": item["size"]})
+        elif item["type"] == "directory":
+          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
+          files.extend(subfiles)
+      return files
+    else:
+      raise Exception(f"Failed to fetch file list: {response.status}")
+
+async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Optional[Callable[[int, int], None]] = None) -> Path:
+  if (target_dir/path).exists(): return target_dir/path
+  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
+  url = urljoin(base_url, path)
+  headers = await get_auth_headers()
+  async with session.get(url, headers=headers) as r:
+    assert r.status == 200, r.status
+    length = int(r.headers.get('content-length', 0))
+    n_read = 0
+    async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
+      while chunk := await r.content.read(16384): on_progress(n_read := n_read + await temp_file.write(chunk), length)
+      await aios.rename(temp_file.name, target_dir/path)
+    return target_dir/path
+
+def calculate_repo_progress(repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
+  all_total_bytes = sum([p.total for p in file_progress.values()])
+  all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()])
+  elapsed_time = time.time() - all_start_time
+  all_speed = all_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
+  all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
+  status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress"
+  return RepoProgressEvent(repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes, all_total_bytes, all_speed, all_eta, file_progress, status)
+
+async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
+  target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
+  async with aiohttp.ClientSession() as session:
+    index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
+    async with aiofiles.open(index_file, 'r') as f:
+      index_data = json.loads(await f.read())
+    return index_data.get("weight_map")
+
+async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> list[str]:
+  weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
+  return get_allow_patterns(weight_map, shard)
+
+async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
+  if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
+  repo_id = get_repo(shard.model_id, inference_engine_classname)
+  revision = "main"
+  target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
+  if not skip_download: await aios.makedirs(target_dir, exist_ok=True)
+
+  if repo_id is None:
+    raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}")
+
+  allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname) if not skip_download else None
+  if DEBUG >= 3: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
+
+  all_start_time = time.time()
+  async with aiohttp.ClientSession() as session:
+    file_list = await fetch_file_list(session, repo_id, revision)
+    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
+    file_progress: Dict[str, RepoFileProgressEvent] = {}
+    def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
+      start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
+      speed = curr_bytes / (time.time() - start_time)
+      eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
+      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, curr_bytes, total_bytes, speed, eta, "in_progress", start_time)
+      on_progress.trigger_all(shard, calculate_repo_progress(repo_id, revision, file_progress, all_start_time))
+      if DEBUG >= 2: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
+    for file in filtered_file_list:
+      downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
+      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time())
+
+    semaphore = asyncio.Semaphore(max_parallel_downloads)
+    async def download_with_semaphore(file):
+      async with semaphore:
+        await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
+    if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
+    final_repo_progress = calculate_repo_progress(repo_id, revision, file_progress, all_start_time)
+    on_progress.trigger_all(shard, final_repo_progress)
+    return target_dir, final_repo_progress
+
+def new_shard_downloader() -> ShardDownloader:
+  return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))
+
+class SingletonShardDownloader(ShardDownloader):
+  def __init__(self, shard_downloader: ShardDownloader):
+    self.shard_downloader = shard_downloader
+    self.active_downloads: Dict[Shard, asyncio.Task] = {}
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self.shard_downloader.on_progress
+
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    if shard not in self.active_downloads: self.active_downloads[shard] = asyncio.create_task(self.shard_downloader.ensure_shard(shard, inference_engine_name))
+    try: return await self.active_downloads[shard]
+    finally:
+      if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]
+
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+
+class CachedShardDownloader(ShardDownloader):
+  def __init__(self, shard_downloader: ShardDownloader):
+    self.shard_downloader = shard_downloader
+    self.cache: Dict[tuple[str, Shard], Path] = {}
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self.shard_downloader.on_progress
+
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    if (inference_engine_name, shard) in self.cache:
+      if DEBUG >= 2: print(f"ensure_shard cache hit {shard=} for {inference_engine_name}")
+      return self.cache[(inference_engine_name, shard)]
+    if DEBUG >= 2: print(f"ensure_shard cache miss {shard=} for {inference_engine_name}")
+    target_dir = await self.shard_downloader.ensure_shard(shard, inference_engine_name)
+    self.cache[(inference_engine_name, shard)] = target_dir
+    return target_dir
+
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+
+class NewShardDownloader(ShardDownloader):
+  def __init__(self):
+    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self._on_progress
+
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
+    return target_dir
+
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+    if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
+    downloads = await asyncio.gather(*[download_shard(build_base_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
+    if DEBUG >= 6: print("Downloaded shards:", downloads)
+    if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
+    return [d for d in downloads if not isinstance(d, Exception)]
+

+ 20 - 0
exo/download/hf/test_new_shard_download.py

@@ -0,0 +1,20 @@
+from exo.download.hf.new_shard_download import download_shard, NewShardDownloader
+from exo.inference.shard import Shard
+from exo.models import get_model_id
+from pathlib import Path
+import asyncio
+
+async def test_new_shard_download():
+  shard_downloader = NewShardDownloader()
+  shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
+  await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
+  download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
+  print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})
+
+async def test_helpers():
+  print(get_model_id("mlx-community/Llama-3.3-70B-Instruct-4bit"))
+  print(get_model_id("fjsekljfd"))
+
+if __name__ == "__main__":
+  asyncio.run(test_helpers())
+

+ 2 - 2
exo/download/shard_download.py

@@ -27,7 +27,7 @@ class ShardDownloader(ABC):
     pass
 
   @abstractmethod
-  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
     """Get the download status of shards.
     
     Returns:
@@ -45,5 +45,5 @@ class NoopShardDownloader(ShardDownloader):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return AsyncCallbackSystem()
 
-  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> Optional[Dict[str, float]]:
     return None

+ 3 - 4
exo/main.py

@@ -24,7 +24,7 @@ from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
-from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.download.hf.new_shard_download import new_shard_downloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
@@ -117,8 +117,7 @@ print_yellow_exo()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
-                                                      max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
+shard_downloader: ShardDownloader = new_shard_downloader() if args.inference_engine != "dummy" else NoopShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 print(f"Inference engine name after selection: {inference_engine_name}")
 
@@ -175,10 +174,10 @@ node = Node(
   None,
   inference_engine,
   discovery,
+  shard_downloader,
   partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
   max_generate_tokens=args.max_generate_tokens,
   topology_viz=topology_viz,
-  shard_downloader=shard_downloader,
   default_sample_temperature=args.default_temp
 )
 server = GRPCServer(node, args.node_host, args.node_port)

+ 10 - 1
exo/models.py

@@ -175,8 +175,11 @@ pretty_name = {
   "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
   "deepseek-coder-v2.5": "Deepseek Coder V2.5",
   "deepseek-v3": "Deepseek V3",
+  "deepseek-v3-3bit": "Deepseek V3 (3-bit)",
   "deepseek-r1": "Deepseek R1",
+  "deepseek-r1-3bit": "Deepseek R1 (3-bit)",
   "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
+  "qwen-2.5-0.5b": "Qwen 2.5 0.5B",
   "qwen-2.5-1.5b": "Qwen 2.5 1.5B",
   "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
   "qwen-2.5-3b": "Qwen 2.5 3B",
@@ -232,6 +235,12 @@ pretty_name = {
 def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
   return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
 
+def get_model_id(repo_id: str) -> Optional[str]:
+  return next((model_id for model_id, card in model_cards.items() if repo_id in card["repo"].values()), None)
+
+def get_pretty_name(model_id: str) -> Optional[str]:
+  return pretty_name.get(model_id, None)
+
 def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
   repo = get_repo(model_id, inference_engine_classname)
   n_layers = model_cards.get(model_id, {}).get("layers", 0)
@@ -239,7 +248,7 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
     return None
   return Shard(model_id, 0, 0, n_layers)
 
-def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
+def get_supported_models(supported_inference_engine_lists: Optional[List[List[str]]] = None) -> List[str]:
   if not supported_inference_engine_lists:
     return list(model_cards.keys())
 

+ 3 - 3
exo/orchestration/node.py

@@ -15,7 +15,7 @@ from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
 from exo.download.hf.hf_helpers import RepoProgressEvent
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
-from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.download.shard_download import ShardDownloader
 
 class Node:
   def __init__(
@@ -24,16 +24,17 @@ class Node:
     server: Server,
     inference_engine: InferenceEngine,
     discovery: Discovery,
+    shard_downloader: ShardDownloader,
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
     default_sample_temperature: float = 0.0,
     topology_viz: Optional[TopologyViz] = None,
-    shard_downloader: Optional[HFShardDownloader] = None,
   ):
     self.id = _id
     self.inference_engine = inference_engine
     self.server = server
     self.discovery = discovery
+    self.shard_downloader = shard_downloader
     self.partitioning_strategy = partitioning_strategy
     self.peers: List[PeerHandle] = {}
     self.topology: Topology = Topology()
@@ -52,7 +53,6 @@ class Node:
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
     self.topology_inference_engines_pool: List[List[str]] = []
-    self.shard_downloader = shard_downloader
     self.outstanding_requests = {}
 
   async def start(self, wait_for_peers: int = 0) -> None:

+ 2 - 2
exo/orchestration/test_node.py

@@ -5,7 +5,7 @@ import pytest
 
 from .node import Node
 from exo.networking.peer_handle import PeerHandle
-
+from exo.download.shard_download import NoopShardDownloader
 
 class TestNode(unittest.IsolatedAsyncioTestCase):
   def setUp(self):
@@ -22,7 +22,7 @@ class TestNode(unittest.IsolatedAsyncioTestCase):
     mock_peer2.id.return_value = "peer2"
     self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
 
-    self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery)
+    self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery, NoopShardDownloader())
 
   async def asyncSetUp(self):
     await self.node.start()

+ 3 - 3
exo/tinychat/index.js

@@ -75,12 +75,12 @@ document.addEventListener("alpine:init", () => {
       while (true) {
         try {
           await this.populateSelector();
-          // Wait 5 seconds before next poll
-          await new Promise(resolve => setTimeout(resolve, 5000));
+          // Wait 15 seconds before next poll
+          await new Promise(resolve => setTimeout(resolve, 15000));
         } catch (error) {
           console.error('Model polling error:', error);
           // If there's an error, wait before retrying
-          await new Promise(resolve => setTimeout(resolve, 5000));
+          await new Promise(resolve => setTimeout(resolve, 15000));
         }
       }
     },