Browse Source

bring tinygrad to parity with mlx on llama models, show progress of each download file

Alex Cheema 9 months ago
parent
commit
d22ed12e7b

+ 13 - 16
exo/api/chatgpt_api.py

@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoProcessor
 from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
+import traceback
 from exo import DEBUG, VERSION
 from exo.helpers import terminal_link, PrefixDict
 from exo.inference.shard import Shard
@@ -16,20 +17,22 @@ shard_mappings = {
   ### llama
   "llama-3.1-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3.1-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
   },
   "llama-3.1-405b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
   },
   "llama-3-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
   },
   ### mistral
   "mistral-nemo": {
@@ -79,7 +82,7 @@ class ChatCompletionRequest:
 
 async def resolve_tokenizer(model_id: str):
   try:
-    if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}")
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
     processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
     if not hasattr(processor, 'eos_token_id'):
       processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
@@ -89,21 +92,18 @@ async def resolve_tokenizer(model_id: str):
       processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
-    if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}")
-    import traceback
+    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
 
-    if DEBUG >= 2: print(traceback.format_exc())
+    if DEBUG >= 4: print(traceback.format_exc())
 
   try:
-    if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
     return AutoTokenizer.from_pretrained(model_id)
   except Exception as e:
-    if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
-    import traceback
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())
 
-    if DEBUG >= 2: print(traceback.format_exc())
-
-  if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
+  if DEBUG >= 4: print(f"Trying mlx tokenizer for {model_id}")
   from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
 
   return load_tokenizer(await get_model_path(model_id))
@@ -308,10 +308,7 @@ class ChatGPTAPI:
     try:
       await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
     except Exception as e:
-      if DEBUG >= 2:
-        import traceback
-
-        traceback.print_exc()
+      if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
 
     try:
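
Each llama entry in shard_mappings is now keyed by inference engine class name, so a tinygrad node resolves to a full-weight repo while an MLX node keeps its 4-bit mlx-community build. A minimal sketch of the lookup, assuming the API selects by model name and engine class name (the actual call site is outside this diff):

    # Hypothetical lookup; mirrors the shard_mappings structure above.
    shard = shard_mappings["llama-3-8b"]["TinygradDynamicShardInferenceEngine"]
    print(shard.model_id)  # TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R
    print(shard.n_layers)  # 32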

+ 24 - 0
exo/helpers.py

@@ -201,3 +201,27 @@ def get_or_create_node_id():
     except Exception as e:
         if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
         return str(uuid.uuid4())
+
+def pretty_print_bytes(size_in_bytes: int) -> str:
+    if size_in_bytes < 1024:
+        return f"{size_in_bytes} B"
+    elif size_in_bytes < 1024 ** 2:
+        return f"{size_in_bytes / 1024:.2f} KB"
+    elif size_in_bytes < 1024 ** 3:
+        return f"{size_in_bytes / (1024 ** 2):.2f} MB"
+    elif size_in_bytes < 1024 ** 4:
+        return f"{size_in_bytes / (1024 ** 3):.2f} GB"
+    else:
+        return f"{size_in_bytes / (1024 ** 4):.2f} TB"
+
+def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
+    if bytes_per_second < 1024:
+        return f"{bytes_per_second} B/s"
+    elif bytes_per_second < 1024 ** 2:
+        return f"{bytes_per_second / 1024:.2f} KB/s"
+    elif bytes_per_second < 1024 ** 3:
+        return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
+    elif bytes_per_second < 1024 ** 4:
+        return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
+    else:
+        return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
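
Both helpers walk the same power-of-1024 thresholds, one unit per branch, with raw bytes left unformatted below 1 KB. A few spot checks of the boundaries:

    from exo.helpers import pretty_print_bytes, pretty_print_bytes_per_second

    print(pretty_print_bytes(512))                   # 512 B
    print(pretty_print_bytes(1234))                  # 1.21 KB
    print(pretty_print_bytes(5 * 1024 ** 3))         # 5.00 GB
    print(pretty_print_bytes_per_second(2_500_000))  # 2.38 MB/s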

+ 78 - 24
exo/inference/hf_helpers.py

@@ -103,19 +103,69 @@ async def fetch_file_list(session, repo_id, revision, path=""):
 class HFRepoFileProgressEvent:
     file_path: str
     downloaded: int
+    downloaded_this_session: int
     total: int
-    speed: float
+    speed: int
     eta: timedelta
     status: Literal["not_started", "in_progress", "complete"]
 
+    def to_dict(self):
+        return {
+            "file_path": self.file_path,
+            "downloaded": self.downloaded,
+            "downloaded_this_session": self.downloaded_this_session,
+            "total": self.total,
+            "speed": self.speed,
+            "eta": self.eta.total_seconds(),
+            "status": self.status
+        }
+
+    @classmethod
+    def from_dict(cls, data):
+        # Convert eta from seconds back to timedelta
+        if 'eta' in data:
+            data['eta'] = timedelta(seconds=data['eta'])
+        return cls(**data)
+
 @dataclass
 class HFRepoProgressEvent:
     completed_files: int
     total_files: int
     downloaded_bytes: int
+    downloaded_bytes_this_session: int
     total_bytes: int
+    overall_speed: int
     overall_eta: timedelta
     file_progress: Dict[str, HFRepoFileProgressEvent]
+    status: Literal["not_started", "in_progress", "complete"]
+
+    def to_dict(self):
+        return {
+            "completed_files": self.completed_files,
+            "total_files": self.total_files,
+            "downloaded_bytes": self.downloaded_bytes,
+            "downloaded_bytes_this_session": self.downloaded_bytes_this_session,
+            "total_bytes": self.total_bytes,
+            "overall_speed": self.overall_speed,
+            "overall_eta": self.overall_eta.total_seconds(),
+            "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
+            "status": self.status
+        }
+
+    @classmethod
+    def from_dict(cls, data):
+        # Convert overall_eta from seconds back to timedelta
+        if 'overall_eta' in data:
+            data['overall_eta'] = timedelta(seconds=data['overall_eta'])
+
+        # Parse file_progress
+        if 'file_progress' in data:
+            data['file_progress'] = {
+                k: HFRepoFileProgressEvent.from_dict(v)
+                for k, v in data['file_progress'].items()
+            }
+
+        return cls(**data)
 
 HFRepoFileProgressCallback = Callable[[HFRepoFileProgressEvent], Coroutine[Any, Any, None]]
 HFRepoProgressCallback = Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]
@@ -143,11 +193,12 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
     async with session.get(url, headers=headers) as response:
         total_size = int(response.headers.get('Content-Length', 0))
         downloaded_size = local_file_size
+        downloaded_this_session = 0
         mode = 'ab' if use_range_request else 'wb'
         if downloaded_size == total_size:
             if DEBUG >= 2: print(f"File already downloaded: {file_path}")
             if progress_callback:
-                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
+                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
             return
 
         if response.status == 200:
@@ -170,7 +221,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
                 if downloaded_size == total_size:
                     if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
                     if progress_callback:
-                        await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
+                        await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
                     return
             except ValueError:
                 if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
@@ -181,7 +232,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
         if downloaded_size == total_size:
             print(f"File already downloaded: {file_path}")
             if progress_callback:
-                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
+                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
             return
 
         DOWNLOAD_CHUNK_SIZE = 32768
@@ -190,13 +241,15 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
             async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
                 f.write(chunk)
                 downloaded_size += len(chunk)
+                downloaded_this_session += len(chunk)
                 if progress_callback and total_size:
                     elapsed_time = (datetime.now() - start_time).total_seconds()
-                    speed = downloaded_size / elapsed_time if elapsed_time > 0 else 0
+                    speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
                     remaining_size = total_size - downloaded_size
                     eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
                     status = "in_progress" if downloaded_size < total_size else "complete"
-                    await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, speed, eta, status))
+                    if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
+                    await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
         if DEBUG >= 2: print(f"Downloaded: {file_path}")
 
 async def download_all_files(repo_id: str, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
@@ -229,35 +282,36 @@ async def download_all_files(repo_id: str, revision: str = "main", progress_call
         file_list = await fetch_file_list(session, repo_id, revision)
         filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
         total_files = len(filtered_file_list)
-        completed_files = 0
         total_bytes = sum(file["size"] for file in filtered_file_list)
-        downloaded_bytes = 0
-        file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
+        file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
         start_time = datetime.now()
 
-        async def download_with_progress(file_info):
-            nonlocal completed_files, downloaded_bytes, file_progress
-
+        async def download_with_progress(file_info, progress_state):
             async def file_progress_callback(event: HFRepoFileProgressEvent):
-                nonlocal downloaded_bytes, file_progress
-                downloaded_bytes += event.downloaded - file_progress[event.file_path].downloaded
+                progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
+                progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
                 file_progress[event.file_path] = event
                 if progress_callback:
                     elapsed_time = (datetime.now() - start_time).total_seconds()
-                    overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
-                    overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
-                    await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))
+                    overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+                    remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+                    overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+                    status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
+                    await progress_callback(HFRepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
 
             await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
-            completed_files += 1
-            file_progress[file_info["path"]] = HFRepoFileProgressEvent(file_info["path"], file_info["size"], file_info["size"], 0, timedelta(0), "complete")
+            progress_state['completed_files'] += 1
+            file_progress[file_info["path"]] = HFRepoFileProgressEvent(file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
             if progress_callback:
                 elapsed_time = (datetime.now() - start_time).total_seconds()
-                overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
-                overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
-                await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))
-
-        tasks = [download_with_progress(file_info) for file_info in filtered_file_list]
+                overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+                remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+                overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+                status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+                await progress_callback(HFRepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+
+        progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
+        tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
         await asyncio.gather(*tasks)
 
     return snapshot_dir
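
The nonlocal counters in download_with_progress are replaced by a mutable progress_state dict, so every task spawned by asyncio.gather updates one shared tally. Since broadcast_opaque_status carries JSON strings between nodes, the new to_dict/from_dict pair is what lets a progress event survive that trip. A round-trip sketch, assuming HFRepoFileProgressEvent is a dataclass like HFRepoProgressEvent (its decorator sits just above this hunk), so the keyword construction in from_dict works:

    import json
    from datetime import timedelta
    from exo.inference.hf_helpers import HFRepoFileProgressEvent, HFRepoProgressEvent

    # Field values are illustrative only.
    file_event = HFRepoFileProgressEvent(
        "model.safetensors", 1024, 1024, 4096, 512, timedelta(seconds=6), "in_progress")
    event = HFRepoProgressEvent(
        0, 1, 1024, 1024, 4096, 512, timedelta(seconds=6),
        {"model.safetensors": file_event}, "in_progress")

    wire = json.dumps(event.to_dict())  # eta fields serialize as plain seconds
    restored = HFRepoProgressEvent.from_dict(json.loads(wire))
    assert restored.overall_eta == timedelta(seconds=6)
    assert restored.file_progress["model.safetensors"].status == "in_progress"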

+ 3 - 4
exo/inference/tinygrad/inference.py

@@ -1,16 +1,14 @@
-from functools import partial
 from pathlib import Path
 from typing import List, Optional, Union, Callable, Coroutine, Any
 import json
-from tiktoken.load import load_tiktoken_bpe
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
 from tinygrad import Tensor, nn, Context, GlobalCounters
-from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
+from tinygrad.helpers import tqdm
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
 import numpy as np
-from exo.inference.hf_helpers import HFRepoProgressCallback, HFRepoProgressEvent, download_all_files, get_repo_root
+from exo.inference.hf_helpers import HFRepoProgressCallback, HFRepoProgressEvent, download_all_files
 
 MODEL_PARAMS = {
   "8B": {
@@ -88,6 +86,7 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", quantize=
       )
   else:
     weights = load(str(model_path))
+
   if "model.embed_tokens.weight" in weights:
     weights = convert_from_huggingface(
       weights,

+ 7 - 9
exo/orchestration/standard_node.py

@@ -3,6 +3,7 @@ import json
 import asyncio
 import uuid
 import time
+import traceback
 from typing import List, Dict, Optional, Tuple, Union
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
@@ -13,6 +14,7 @@ from exo.topology.partitioning_strategy import Partition, PartitioningStrategy,
 from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
+from exo.inference.hf_helpers import HFRepoProgressEvent
 
 
 class StandardNode(Node):
@@ -54,12 +56,14 @@ class StandardNode(Node):
             self.current_topology.active_node_id = None
       download_progress = None
       if status_data.get("type", "") == "download_progress":
-        if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('current')}/{status_data.get('total')} ({round(status_data.get('current') / status_data.get('total') * 100, 2)}%)")
+        if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
         if status_data.get("node_id") == self.id:
-          download_progress = (status_data.get('current'), status_data.get('total'))
+          download_progress = HFRepoProgressEvent.from_dict(status_data.get('progress'))
       if self.topology_viz:
         self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), download_progress)
-    except json.JSONDecodeError:
+    except Exception as e:
+      if DEBUG >= 1: print(f"Error updating visualization: {e}")
+      traceback.print_exc()
       pass
 
   async def start(self, wait_for_peers: int = 0) -> None:
@@ -231,8 +235,6 @@ class StandardNode(Node):
       return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
-      import traceback
-
       traceback.print_exc()
       return None
 
@@ -368,8 +370,6 @@ class StandardNode(Node):
         print(f"Timeout broadcasting result to {peer.id()}")
       except Exception as e:
         print(f"Error broadcasting result to {peer.id()}: {e}")
-        import traceback
-
         traceback.print_exc()
 
     await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
@@ -383,8 +383,6 @@ class StandardNode(Node):
         print(f"Timeout sending opaque status to {peer.id()}")
       except Exception as e:
         print(f"Error sending opaque status to {peer.id()}: {e}")
-        import traceback
-
         traceback.print_exc()
 
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
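
For reference, the download_progress status message this handler now parses (produced by the callback wired up in main.py below) has roughly this shape; node_id and all counts are illustrative:

    status_data = {
        "type": "download_progress",
        "node_id": "node-a",
        "progress": {  # HFRepoProgressEvent.to_dict() output
            "completed_files": 1,
            "total_files": 4,
            "downloaded_bytes": 2147483648,             # 2 GB
            "downloaded_bytes_this_session": 2147483648,
            "total_bytes": 34359738368,                 # 32 GB
            "overall_speed": 52428800,                  # 50 MB/s
            "overall_eta": 614.4,                       # seconds
            "file_progress": {},                        # per-file event dicts, omitted here
            "status": "in_progress",
        },
    }

Broadening the except clause from json.JSONDecodeError to Exception also means a malformed payload now logs a traceback instead of failing silently.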

+ 56 - 4
exo/viz/topology_viz.py

@@ -1,6 +1,6 @@
 import math
 from typing import List, Optional, Tuple
-from exo.helpers import exo_text
+from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
 from rich.console import Console
@@ -8,8 +8,10 @@ from rich.panel import Panel
 from rich.text import Text
 from rich.live import Live
 from rich.style import Style
+from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
+from rich.table import Table
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
-
+from exo.inference.hf_helpers import HFRepoProgressEvent
 
 class TopologyViz:
   def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
@@ -24,7 +26,7 @@ class TopologyViz:
     self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
     self.live_panel.start()
 
-  def update_visualization(self, topology: Topology, partitions: List[Partition], download_progress: Optional[Tuple[int, int]] = None):
+  def update_visualization(self, topology: Topology, partitions: List[Partition], download_progress: HFRepoProgressEvent = None):
     self.topology = topology
     self.partitions = partitions
     self.download_progress = download_progress
@@ -34,7 +36,7 @@ class TopologyViz:
     self.panel.renderable = self._generate_layout()
     # Update the panel title with the number of nodes and partitions
     node_count = len(self.topology.nodes)
-    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''}){f' {self.download_progress[0]/self.download_progress[1]:.2%} Downloaded' if self.download_progress else ''}"
+    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
     self.live_panel.update(self.panel, refresh=True)
 
   def _generate_layout(self) -> str:
@@ -47,6 +49,31 @@ class TopologyViz:
     # Generate visualization
     visualization = [[" " for _ in range(100)] for _ in range(55)]  # Decreased height
 
+    # Draw download first so everything else is drawn on top
+    # If a download is in progress, show the download info summary
+    if self.download_progress and self.download_progress.status != "complete":
+        download_summary = _generate_download_summary(self.download_progress)
+        download_panel = Panel(
+            download_summary,
+            title="Download Progress",
+            border_style="cyan",
+            expand=False,
+            width=96,  # Further reduced to ensure it fits within the visualization
+            height=None  # Allow the panel to adjust its height based on content
+        )
+        console = Console(width=98, height=55)  # Reduced console width
+        with console.capture() as capture:
+            console.print(download_panel)
+        download_lines = capture.get().split('\n')
+        download_start_y = 15
+        panel_width = len(max(download_lines, key=len))
+        start_x = max(1, (100 - panel_width) // 2)  # Ensure start_x is at least 1 to avoid left border cut-off
+        for i, line in enumerate(download_lines):
+            for j, char in enumerate(line):
+                if 1 <= start_x + j < 99 and download_start_y + i < 55:  # Ensure we don't write to the rightmost column
+                    visualization[download_start_y + i][start_x + j] = char
+
+
     # Add exo_text at the top in bright yellow
     exo_lines = exo_text.split("\n")
     yellow_style = Style(color="bright_yellow")
@@ -168,3 +195,28 @@ class TopologyViz:
 
     # Convert to string
     return "\n".join("".join(str(char) for char in row) for row in visualization)
+
+def _generate_download_summary(download_progress) -> Table:
+    summary = Table(show_header=False, box=None, padding=(0, 1))
+    summary.add_column("Info", style="cyan", no_wrap=True)
+    summary.add_column("Progress", style="cyan", no_wrap=True)
+    summary.add_column("Percentage", style="cyan", no_wrap=True)
+
+    title = f"Downloading model ({download_progress.completed_files}/{download_progress.total_files}):"
+    summary.add_row(Text(title, style="bold"))
+    progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
+    summary.add_row(progress_info)
+
+    eta_info = f"ETA: {download_progress.overall_eta}"
+    summary.add_row(eta_info)
+
+    summary.add_row("")  # Empty row for spacing
+
+    for file_path, file_progress in download_progress.file_progress.items():
+      if file_progress.status != "complete":
+        progress = int(file_progress.downloaded / file_progress.total * 20)  # Increased bar width
+        bar = f"[{'=' * progress}{' ' * (20 - progress)}]"
+        percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
+        summary.add_row(Text(file_path[:20], style="cyan"), bar, percentage)  # Increased file path length
+
+    return summary
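
_generate_download_summary builds a plain rich Table, so it can be rendered on its own to eyeball the layout. A sketch with a hand-built event, constructed positionally as in hf_helpers (all values illustrative):

    from datetime import timedelta
    from rich.console import Console
    from exo.inference.hf_helpers import HFRepoFileProgressEvent, HFRepoProgressEvent
    from exo.viz.topology_viz import _generate_download_summary

    f = HFRepoFileProgressEvent(
        "model-00001-of-00004.safetensors",
        2 * 1024 ** 3, 2 * 1024 ** 3, 8 * 1024 ** 3,  # downloaded, this session, total
        50 * 1024 ** 2, timedelta(seconds=120), "in_progress")
    event = HFRepoProgressEvent(
        1, 4, 2 * 1024 ** 3, 2 * 1024 ** 3, 32 * 1024 ** 3,
        50 * 1024 ** 2, timedelta(seconds=600), {f.file_path: f}, "in_progress")

    Console().print(_generate_download_summary(event))  # open file shows a 25% bar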

+ 1 - 1
main.py

@@ -60,7 +60,7 @@ node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)
-inference_engine.set_progress_callback(lambda event: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": event.downloaded_bytes, "total": event.total_bytes}))))
+inference_engine.set_progress_callback(lambda event: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()}))))
 
 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""