1
0
Эх сурвалжийг харах

async model downloading with download progress. fixes #102. related: #16 #104

Alex Cheema 9 сар өмнө
parent
commit
d6a7e46324

+ 1 - 2
exo/helpers.py

@@ -1,7 +1,6 @@
 import os
 import asyncio
-from typing import Any, Callable, TypeVar, Optional, Dict, Generic, Tuple, List
-from collections import defaultdict
+from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
 import random
 import platform

+ 7 - 3
exo/inference/mlx/sharded_inference_engine.py

@@ -4,12 +4,13 @@ from ..inference_engine import InferenceEngine
 from .sharded_model import StatefulShardedModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
-from typing import Optional
+from typing import Optional, Callable
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self):
+  def __init__(self, on_download_progress: Callable[[int, int], None] = None):
     self.shard = None
+    self.on_download_progress = on_download_progress
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
@@ -32,6 +33,9 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return
 
-    model_shard, self.tokenizer = await load_shard(shard.model_id, shard)
+    model_shard, self.tokenizer = await load_shard(shard.model_id, shard, on_download_progress=self.on_download_progress)
     self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
     self.shard = shard
+
+  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
+    self.on_download_progress = on_download_progress

+ 54 - 11
exo/inference/mlx/sharded_utils.py

@@ -8,16 +8,19 @@ import asyncio
 import aiohttp
 from functools import partial
 from pathlib import Path
-from typing import Optional, Tuple
-import requests
+from typing import Optional, Tuple, Union, List, Callable
 from PIL import Image
 from io import BytesIO
 import base64
+import os
 
 from exo import DEBUG
 import mlx.core as mx
 import mlx.nn as nn
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, list_repo_tree, get_paths_info
+from huggingface_hub.utils import filter_repo_objects
+from huggingface_hub.file_download import repo_folder_name
+from huggingface_hub.constants import HF_HUB_CACHE
 from huggingface_hub.utils._errors import RepositoryNotFoundError
 from transformers import AutoProcessor
 
@@ -144,12 +147,50 @@ def load_model_shard(
   return model
 
 
-async def snapshot_download_async(*args, **kwargs):
-  func = partial(snapshot_download, *args, **kwargs)
-  return await asyncio.get_event_loop().run_in_executor(None, func)
-
-
-async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
+async def get_repo_size(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
+  it = await asyncio.to_thread(list_repo_tree, repo_id, revision=revision, repo_type=repo_type)
+  files = list(filter_repo_objects(it, allow_patterns=allow_patterns, key=lambda f: f.path))
+  return sum(file.size for file in files if file.size is not None)
+
+async def monitor_progress(dir, total_size, print_progress=False, on_progress: Callable[[int, int], None] = None):
+    while True:
+        await asyncio.sleep(0.1)
+        current_size = sum(os.path.getsize(os.path.join(root, file))
+                           for root, _, files in os.walk(dir)
+                           for file in files)
+        progress = min(current_size / total_size * 100, 100)
+        if print_progress:
+          print(f"\rProgress: {progress:.2f}% ({current_size}/{total_size} bytes)", end="", flush=True)
+        if on_progress:
+          on_progress(current_size, total_size)
+        if progress >= 100:
+          if print_progress:
+            print("\nDownload complete!")
+          break
+
+async def download_repo(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
+  # Use snapshot_download in a separate thread to not block the event loop
+  return await asyncio.to_thread(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
+
+async def download_async_with_progress(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, on_progress: Callable[[int, int], None] = None):
+  storage_folder = os.path.join(HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type="model"))
+  # os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
+  # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
+
+  print(f"Estimating size of repository: {repo_id}")
+  total_size = await get_repo_size(repo_id)
+  print(f"Estimated total size: {total_size} bytes")
+
+  # Create tasks for download and progress checking
+  download_task = asyncio.create_task(download_repo(repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type))
+  progress_task = asyncio.create_task(monitor_progress(storage_folder, total_size, on_progress=on_progress))
+
+  # Wait for both tasks to complete
+  result = await asyncio.gather(download_task, progress_task)
+  return result[0]  # Return the result from download_task
+
+
+async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
   """
   Ensures the model is available locally. If the path does not exist locally,
   it is downloaded from the Hugging Face Hub.
@@ -165,7 +206,7 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -
   if not model_path.exists():
     try:
       model_path = Path(
-        await snapshot_download_async(
+        await download_async_with_progress(
           repo_id=path_or_hf_repo,
           revision=revision,
           allow_patterns=[
@@ -176,6 +217,7 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -
             "*.tiktoken",
             "*.txt",
           ],
+          on_progress=on_download_progress,
         )
       )
     except RepositoryNotFoundError:
@@ -196,6 +238,7 @@ async def load_shard(
   model_config={},
   adapter_path: Optional[str] = None,
   lazy: bool = False,
+  on_download_progress: Callable[[int, int], None] = None,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
   """
   Load the model and tokenizer from a given path or a huggingface repository.
@@ -218,7 +261,7 @@ async def load_shard(
    FileNotFoundError: If config file or safetensors are not found.
    ValueError: If model class or args class are not found.
   """
-  model_path = await get_model_path(path_or_hf_repo)
+  model_path = await get_model_path(path_or_hf_repo, on_download_progress=on_download_progress)
 
   model = load_model_shard(model_path, shard, lazy, model_config)
   if adapter_path is not None:

+ 7 - 1
exo/orchestration/standard_node.py

@@ -52,8 +52,13 @@ class StandardNode(Node):
         elif status_data.get("status", "").startswith("end_"):
           if status_data.get("node_id") == self.current_topology.active_node_id:
             self.current_topology.active_node_id = None
+      download_progress = None
+      if status_data.get("type", "") == "download_progress":
+        if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('current')}/{status_data.get('total')} ({round(status_data.get('current') / status_data.get('total') * 100, 2)}%)")
+        if status_data.get("node_id") == self.id:
+          download_progress = (status_data.get('current'), status_data.get('total'))
       if self.topology_viz:
-        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
+        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), download_progress)
     except json.JSONDecodeError:
       pass
 
@@ -370,6 +375,7 @@ class StandardNode(Node):
     await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
 
   async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
+    if DEBUG >= 5: print(f"Broadcasting opaque status: {request_id=} {status=}")
     async def send_status_to_peer(peer):
       try:
         await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)

+ 5 - 3
exo/viz/topology_viz.py

@@ -1,5 +1,5 @@
 import math
-from typing import List
+from typing import List, Optional, Tuple
 from exo.helpers import exo_text
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
@@ -17,22 +17,24 @@ class TopologyViz:
     self.web_chat_url = web_chat_url
     self.topology = Topology()
     self.partitions: List[Partition] = []
+    self.download_progress = None
 
     self.console = Console()
     self.panel = Panel(self._generate_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
     self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
     self.live_panel.start()
 
-  def update_visualization(self, topology: Topology, partitions: List[Partition]):
+  def update_visualization(self, topology: Topology, partitions: List[Partition], download_progress: Optional[Tuple[int, int]] = None):
     self.topology = topology
     self.partitions = partitions
+    self.download_progress = download_progress
     self.refresh()
 
   def refresh(self):
     self.panel.renderable = self._generate_layout()
     # Update the panel title with the number of nodes and partitions
     node_count = len(self.topology.nodes)
-    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
+    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''}){f' {self.download_progress[0]/self.download_progress[1]:.2%} Downloaded' if self.download_progress else ''}"
     self.live_panel.update(self.panel, refresh=True)
 
   def _generate_layout(self) -> str:

+ 2 - 0
main.py

@@ -1,6 +1,7 @@
 import argparse
 import asyncio
 import signal
+import json
 import uuid
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
@@ -58,6 +59,7 @@ node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)
+inference_engine.set_on_download_progress(lambda current, total: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": current, "total": total}))))
 
 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""

+ 2 - 1
setup.py

@@ -9,7 +9,8 @@ install_requires = [
     "blobfile==2.1.1",
     "grpcio==1.64.1",
     "grpcio-tools==1.64.1",
-    "huggingface-hub==0.23.4",
+    "hf-transfer==0.1.8",
+    "huggingface-hub==0.24.5",
     "Jinja2==3.1.4",
     "numpy==2.0.0",
     "pillow==10.4.0",