
Async model downloading with download progress. Fixes #102. Related: #16, #104.

Alex Cheema, 9 months ago
commit d6a7e46324

+ 1 - 2
exo/helpers.py

@@ -1,7 +1,6 @@
 import os
 import asyncio
-from typing import Any, Callable, TypeVar, Optional, Dict, Generic, Tuple, List
-from collections import defaultdict
+from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
 import random
 import platform

+ 7 - 3
exo/inference/mlx/sharded_inference_engine.py

@@ -4,12 +4,13 @@ from ..inference_engine import InferenceEngine
 from .sharded_model import StatefulShardedModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
-from typing import Optional
+from typing import Optional, Callable
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self):
+  def __init__(self, on_download_progress: Callable[[int, int], None] = None):
     self.shard = None
+    self.on_download_progress = on_download_progress
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
@@ -32,6 +33,9 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return
 
-    model_shard, self.tokenizer = await load_shard(shard.model_id, shard)
+    model_shard, self.tokenizer = await load_shard(shard.model_id, shard, on_download_progress=self.on_download_progress)
     self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
     self.shard = shard
+
+  def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
+    self.on_download_progress = on_download_progress
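
A minimal usage sketch of the new hook, assuming only what the diff above shows (the `print_progress` helper is hypothetical): the callback takes `(current, total)` byte counts and can be supplied at construction time or attached later with `set_on_download_progress`.

```python
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine

# Hypothetical progress printer; only the (current, total) signature
# comes from the diff above.
def print_progress(current: int, total: int) -> None:
  pct = current / total * 100 if total else 0.0
  print(f"\rDownloaded {current}/{total} bytes ({pct:.1f}%)", end="", flush=True)

engine = MLXDynamicShardInferenceEngine(on_download_progress=print_progress)
# ...or attach (or swap) the callback after construction:
engine.set_on_download_progress(print_progress)
```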

+ 54 - 11
exo/inference/mlx/sharded_utils.py

@@ -8,16 +8,19 @@ import asyncio
 import aiohttp
 from functools import partial
 from pathlib import Path
-from typing import Optional, Tuple
-import requests
+from typing import Optional, Tuple, Union, List, Callable
 from PIL import Image
 from io import BytesIO
 import base64
+import os
 
 from exo import DEBUG
 import mlx.core as mx
 import mlx.nn as nn
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, list_repo_tree, get_paths_info
+from huggingface_hub.utils import filter_repo_objects
+from huggingface_hub.file_download import repo_folder_name
+from huggingface_hub.constants import HF_HUB_CACHE
 from huggingface_hub.utils._errors import RepositoryNotFoundError
 from transformers import AutoProcessor
 
@@ -144,12 +147,50 @@ def load_model_shard(
   return model
 
 
-async def snapshot_download_async(*args, **kwargs):
-  func = partial(snapshot_download, *args, **kwargs)
-  return await asyncio.get_event_loop().run_in_executor(None, func)
-
-
-async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
+async def get_repo_size(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
+  it = await asyncio.to_thread(list_repo_tree, repo_id, revision=revision, repo_type=repo_type)
+  files = list(filter_repo_objects(it, allow_patterns=allow_patterns, key=lambda f: f.path))
+  return sum(file.size for file in files if file.size is not None)
+
+async def monitor_progress(dir, total_size, print_progress=False, on_progress: Callable[[int, int], None] = None):
+    while True:
+        await asyncio.sleep(0.1)
+        current_size = sum(os.path.getsize(os.path.join(root, file))
+                           for root, _, files in os.walk(dir)
+                           for file in files)
+        progress = min(current_size / total_size * 100, 100)
+        if print_progress:
+          print(f"\rProgress: {progress:.2f}% ({current_size}/{total_size} bytes)", end="", flush=True)
+        if on_progress:
+          on_progress(current_size, total_size)
+        if progress >= 100:
+          if print_progress:
+            print("\nDownload complete!")
+          break
+
+async def download_repo(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
+  # Use snapshot_download in a separate thread to not block the event loop
+  return await asyncio.to_thread(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
+
+async def download_async_with_progress(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, on_progress: Callable[[int, int], None] = None):
+  storage_folder = os.path.join(HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type="model"))
+  # os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
+  # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
+
+  print(f"Estimating size of repository: {repo_id}")
+  total_size = await get_repo_size(repo_id)
+  print(f"Estimated total size: {total_size} bytes")
+
+  # Create tasks for download and progress checking
+  download_task = asyncio.create_task(download_repo(repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type))
+  progress_task = asyncio.create_task(monitor_progress(storage_folder, total_size, on_progress=on_progress))
+
+  # Wait for both tasks to complete
+  result = await asyncio.gather(download_task, progress_task)
+  return result[0]  # Return the result from download_task
+
+
+async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
   """
   """
   Ensures the model is available locally. If the path does not exist locally,
   Ensures the model is available locally. If the path does not exist locally,
   it is downloaded from the Hugging Face Hub.
   it is downloaded from the Hugging Face Hub.
@@ -165,7 +206,7 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -
   if not model_path.exists():
     try:
       model_path = Path(
-        await snapshot_download_async(
+        await download_async_with_progress(
           repo_id=path_or_hf_repo,
           revision=revision,
           allow_patterns=[
@@ -176,6 +217,7 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -
             "*.tiktoken",
             "*.tiktoken",
             "*.txt",
             "*.txt",
           ],
           ],
+          on_progress=on_download_progress,
         )
       )
     except RepositoryNotFoundError:
@@ -196,6 +238,7 @@ async def load_shard(
   model_config={},
   adapter_path: Optional[str] = None,
   lazy: bool = False,
+  on_download_progress: Callable[[int, int], None] = None,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
   """
   Load the model and tokenizer from a given path or a huggingface repository.
@@ -218,7 +261,7 @@ async def load_shard(
    FileNotFoundError: If config file or safetensors are not found.
    ValueError: If model class or args class are not found.
   """
-  model_path = await get_model_path(path_or_hf_repo)
+  model_path = await get_model_path(path_or_hf_repo, on_download_progress=on_download_progress)
 
   model = load_model_shard(model_path, shard, lazy, model_config)
   if adapter_path is not None:
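
The heart of `download_async_with_progress` above is a two-task pattern: run the blocking `snapshot_download` in a worker thread via `asyncio.to_thread`, while a sibling coroutine polls the byte count on disk against the total estimated from `list_repo_tree`. A stripped-down, self-contained sketch of that pattern, assuming `total > 0` (the function names here are illustrative, not from the codebase):

```python
import asyncio
import os

async def poll_dir_size(path: str, total: int, interval: float = 0.1) -> None:
  # Poll the bytes written under `path` until the expected total arrives.
  while True:
    await asyncio.sleep(interval)
    current = sum(
      os.path.getsize(os.path.join(root, f))
      for root, _, files in os.walk(path)
      for f in files
    )
    print(f"\rProgress: {min(current / total, 1.0):.2%} ({current}/{total} bytes)", end="", flush=True)
    if current >= total:
      break

async def download_with_progress(blocking_download, dest: str, total: int):
  # Run the blocking download in a worker thread so the event loop stays
  # free to run the poller; gather both and return the download result.
  result, _ = await asyncio.gather(
    asyncio.to_thread(blocking_download),
    poll_dir_size(dest, total),
  )
  return result
```

As in the diff, progress is only as accurate as the size estimate, and it counts every file already on disk under the cache folder, including partially downloaded ones, which is what lets this work without hooking huggingface_hub internals.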

+ 7 - 1
exo/orchestration/standard_node.py

@@ -52,8 +52,13 @@ class StandardNode(Node):
         elif status_data.get("status", "").startswith("end_"):
           if status_data.get("node_id") == self.current_topology.active_node_id:
             self.current_topology.active_node_id = None
+      download_progress = None
+      if status_data.get("type", "") == "download_progress":
+        if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('current')}/{status_data.get('total')} ({round(status_data.get('current') / status_data.get('total') * 100, 2)}%)")
+        if status_data.get("node_id") == self.id:
+          download_progress = (status_data.get('current'), status_data.get('total'))
       if self.topology_viz:
-        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
+        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), download_progress)
     except json.JSONDecodeError:
       pass
 
@@ -370,6 +375,7 @@ class StandardNode(Node):
     await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
 
   async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
+    if DEBUG >= 5: print(f"Broadcasting opaque status: {request_id=} {status=}")
     async def send_status_to_peer(peer):
       try:
         await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
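
Because the progress travels as an opaque JSON status string, the payload shape shown in this diff is the implicit contract between sender and receiver. A small round-trip sketch with made-up values:

```python
import json

# Payload shape introduced by this commit; the concrete values are made up.
status = json.dumps({
  "type": "download_progress",
  "node_id": "node-a",
  "current": 1_500_000,
  "total": 3_000_000,
})

# Receiving side, mirroring the handling in standard_node.py above:
status_data = json.loads(status)
if status_data.get("type", "") == "download_progress":
  current, total = status_data["current"], status_data["total"]
  print(f"{current}/{total} ({current / total:.2%})")  # 1500000/3000000 (50.00%)
```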

+ 5 - 3
exo/viz/topology_viz.py

@@ -1,5 +1,5 @@
 import math
-from typing import List
+from typing import List, Optional, Tuple
 from exo.helpers import exo_text
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
@@ -17,22 +17,24 @@ class TopologyViz:
     self.web_chat_url = web_chat_url
     self.topology = Topology()
     self.partitions: List[Partition] = []
+    self.download_progress = None
 
     self.console = Console()
     self.panel = Panel(self._generate_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
     self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
     self.live_panel.start()
 
-  def update_visualization(self, topology: Topology, partitions: List[Partition]):
+  def update_visualization(self, topology: Topology, partitions: List[Partition], download_progress: Optional[Tuple[int, int]] = None):
     self.topology = topology
     self.partitions = partitions
+    self.download_progress = download_progress
     self.refresh()
 
   def refresh(self):
     self.panel.renderable = self._generate_layout()
     # Update the panel title with the number of nodes and partitions
     node_count = len(self.topology.nodes)
-    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
+    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''}){f' {self.download_progress[0]/self.download_progress[1]:.2%} Downloaded' if self.download_progress else ''}"
     self.live_panel.update(self.panel, refresh=True)
 
   def _generate_layout(self) -> str:
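
The title line leans on two small pieces: Python's `%` format spec (`{x:.2%}` multiplies by 100 and appends a percent sign) and Rich's `Live.update(..., refresh=True)` to redraw the panel in place. A standalone sketch of that Rich pattern, independent of TopologyViz:

```python
import time
from rich.console import Console
from rich.live import Live
from rich.panel import Panel

# Keep a single Panel inside a Live region and rewrite its title as
# progress arrives, as TopologyViz.refresh does above.
console = Console()
panel = Panel("cluster layout goes here", title="Exo Cluster (1 node)")
with Live(panel, console=console, auto_refresh=False) as live:
  for current, total in [(1, 4), (2, 4), (4, 4)]:
    panel.title = f"Exo Cluster (1 node) {current / total:.2%} Downloaded"
    live.update(panel, refresh=True)
    time.sleep(0.5)  # stand-in for real progress events
```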

+ 2 - 0
main.py

@@ -1,6 +1,7 @@
 import argparse
 import asyncio
 import signal
+import json
 import uuid
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
@@ -58,6 +59,7 @@ node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)
+inference_engine.set_on_download_progress(lambda current, total: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": current, "total": total}))))
 
 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""

+ 2 - 1
setup.py

@@ -9,7 +9,8 @@ install_requires = [
     "blobfile==2.1.1",
     "blobfile==2.1.1",
     "grpcio==1.64.1",
     "grpcio==1.64.1",
     "grpcio-tools==1.64.1",
     "grpcio-tools==1.64.1",
-    "huggingface-hub==0.23.4",
+    "hf-transfer==0.1.8",
+    "huggingface-hub==0.24.5",
     "Jinja2==3.1.4",
     "Jinja2==3.1.4",
     "numpy==2.0.0",
     "numpy==2.0.0",
     "pillow==10.4.0",
     "pillow==10.4.0",