Browse Source

bring tinygrad to parity with mlx on llama models, show progress of each download file

Alex Cheema 1 year ago
parent
commit
d22ed12e7b

+ 13 - 16
exo/api/chatgpt_api.py

@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoProcessor
 from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
+import traceback
 from exo import DEBUG, VERSION
 from exo.helpers import terminal_link, PrefixDict
 from exo.inference.shard import Shard
@@ -16,20 +17,22 @@ shard_mappings = {
   ### llama
   "llama-3.1-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3.1-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
   },
   "llama-3.1-405b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
   },
   "llama-3-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
   },
   ### mistral
   "mistral-nemo": {
@@ -79,7 +82,7 @@ class ChatCompletionRequest:

 async def resolve_tokenizer(model_id: str):
   try:
-    if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}")
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
     processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
     if not hasattr(processor, 'eos_token_id'):
       processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
@@ -89,21 +92,18 @@ async def resolve_tokenizer(model_id: str):
       processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
-    if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}")
-    import traceback
+    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")

-    if DEBUG >= 2: print(traceback.format_exc())
+    if DEBUG >= 4: print(traceback.format_exc())

   try:
-    if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
     return AutoTokenizer.from_pretrained(model_id)
   except Exception as e:
-    if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
-    import traceback
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())

-    if DEBUG >= 2: print(traceback.format_exc())
-
-  if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
+  if DEBUG >= 4: print(f"Trying mlx tokenizer for {model_id}")
   from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer

   return load_tokenizer(await get_model_path(model_id))
@@ -308,10 +308,7 @@ class ChatGPTAPI:
     try:
       await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
     except Exception as e:
-      if DEBUG >= 2:
-        import traceback
-
-        traceback.print_exc()
+      if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

     try:
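
Note on the mapping above: shard_mappings is keyed first by model name, then by inference-engine class name, so a model without an entry for the running engine (e.g. llama-3.1-405b on tinygrad in this commit) simply is not servable there. A minimal lookup sketch — resolve_shard is a hypothetical helper, not part of this commit:

    from exo.api.chatgpt_api import shard_mappings
    from exo.inference.shard import Shard

    def resolve_shard(model_name: str, engine_class_name: str) -> Shard:
        # Hypothetical illustration of how the two-level mapping is consulted.
        engines = shard_mappings.get(model_name)
        if engines is None:
            raise ValueError(f"Unknown model: {model_name}")
        shard = engines.get(engine_class_name)
        if shard is None:
            raise ValueError(f"{model_name} has no shard for {engine_class_name}")
        return shard

    # resolve_shard("llama-3.1-8b", "TinygradDynamicShardInferenceEngine")
    # -> Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", ...)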

+ 24 - 0
exo/helpers.py

@@ -201,3 +201,27 @@ def get_or_create_node_id():
     except Exception as e:
         if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
         return str(uuid.uuid4())
+
+def pretty_print_bytes(size_in_bytes: int) -> str:
+    if size_in_bytes < 1024:
+        return f"{size_in_bytes} B"
+    elif size_in_bytes < 1024 ** 2:
+        return f"{size_in_bytes / 1024:.2f} KB"
+    elif size_in_bytes < 1024 ** 3:
+        return f"{size_in_bytes / (1024 ** 2):.2f} MB"
+    elif size_in_bytes < 1024 ** 4:
+        return f"{size_in_bytes / (1024 ** 3):.2f} GB"
+    else:
+        return f"{size_in_bytes / (1024 ** 4):.2f} TB"
+
+def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
+    if bytes_per_second < 1024:
+        return f"{bytes_per_second} B/s"
+    elif bytes_per_second < 1024 ** 2:
+        return f"{bytes_per_second / 1024:.2f} KB/s"
+    elif bytes_per_second < 1024 ** 3:
+        return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
+    elif bytes_per_second < 1024 ** 4:
+        return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
+    else:
+        return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"

+ 78 - 24
exo/inference/hf_helpers.py

@@ -103,19 +103,69 @@ async def fetch_file_list(session, repo_id, revision, path=""):
 class HFRepoFileProgressEvent:
     file_path: str
     downloaded: int
+    downloaded_this_session: int
     total: int
-    speed: float
+    speed: int
     eta: timedelta
     status: Literal["not_started", "in_progress", "complete"]

+    def to_dict(self):
+        return {
+            "file_path": self.file_path,
+            "downloaded": self.downloaded,
+            "downloaded_this_session": self.downloaded_this_session,
+            "total": self.total,
+            "speed": self.speed,
+            "eta": self.eta.total_seconds(),
+            "status": self.status
+        }
+
+    @classmethod
+    def from_dict(cls, data):
+        # Convert eta from seconds back to timedelta
+        if 'eta' in data:
+            data['eta'] = timedelta(seconds=data['eta'])
+        return cls(**data)
+
 @dataclass
 class HFRepoProgressEvent:
     completed_files: int
     total_files: int
     downloaded_bytes: int
+    downloaded_bytes_this_session: int
     total_bytes: int
+    overall_speed: int
     overall_eta: timedelta
     file_progress: Dict[str, HFRepoFileProgressEvent]
+    status: Literal["not_started", "in_progress", "complete"]
+
+    def to_dict(self):
+        return {
+            "completed_files": self.completed_files,
+            "total_files": self.total_files,
+            "downloaded_bytes": self.downloaded_bytes,
+            "downloaded_bytes_this_session": self.downloaded_bytes_this_session,
+            "total_bytes": self.total_bytes,
+            "overall_speed": self.overall_speed,
+            "overall_eta": self.overall_eta.total_seconds(),
+            "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
+            "status": self.status
+        }
+
+    @classmethod
+    def from_dict(cls, data):
+        # Convert overall_eta from seconds back to timedelta
+        if 'overall_eta' in data:
+            data['overall_eta'] = timedelta(seconds=data['overall_eta'])
+
+        # Parse file_progress
+        if 'file_progress' in data:
+            data['file_progress'] = {
+                k: HFRepoFileProgressEvent.from_dict(v)
+                for k, v in data['file_progress'].items()
+            }
+
+        return cls(**data)

 HFRepoFileProgressCallback = Callable[[HFRepoFileProgressEvent], Coroutine[Any, Any, None]]
 HFRepoProgressCallback = Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]
@@ -143,11 +193,12 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
     async with session.get(url, headers=headers) as response:
         total_size = int(response.headers.get('Content-Length', 0))
         downloaded_size = local_file_size
+        downloaded_this_session = 0
         mode = 'ab' if use_range_request else 'wb'
         if downloaded_size == total_size:
             if DEBUG >= 2: print(f"File already downloaded: {file_path}")
             if progress_callback:
-                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
+                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
             return

         if response.status == 200:
@@ -170,7 +221,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
                 if downloaded_size == total_size:
                     if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
                     if progress_callback:
-                        await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
+                        await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
                     return
             except ValueError:
                 if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
@@ -181,7 +232,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
         if downloaded_size == total_size:
             print(f"File already downloaded: {file_path}")
             if progress_callback:
-                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
+                await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
             return

         DOWNLOAD_CHUNK_SIZE = 32768
@@ -190,13 +241,15 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
             async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
                 f.write(chunk)
                 downloaded_size += len(chunk)
+                downloaded_this_session += len(chunk)
                 if progress_callback and total_size:
                     elapsed_time = (datetime.now() - start_time).total_seconds()
-                    speed = downloaded_size / elapsed_time if elapsed_time > 0 else 0
+                    speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
                     remaining_size = total_size - downloaded_size
                     eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
                     status = "in_progress" if downloaded_size < total_size else "complete"
-                    await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, speed, eta, status))
+                    if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
+                    await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
         if DEBUG >= 2: print(f"Downloaded: {file_path}")

 async def download_all_files(repo_id: str, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
@@ -229,35 +282,36 @@ async def download_all_files(repo_id: str, revision: str = "main", progress_call
         file_list = await fetch_file_list(session, repo_id, revision)
         filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
         total_files = len(filtered_file_list)
-        completed_files = 0
         total_bytes = sum(file["size"] for file in filtered_file_list)
-        downloaded_bytes = 0
-        file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
+        file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
         start_time = datetime.now()

-        async def download_with_progress(file_info):
-            nonlocal completed_files, downloaded_bytes, file_progress
-
+        async def download_with_progress(file_info, progress_state):
             async def file_progress_callback(event: HFRepoFileProgressEvent):
-                nonlocal downloaded_bytes, file_progress
-                downloaded_bytes += event.downloaded - file_progress[event.file_path].downloaded
+                progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
+                progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
                 file_progress[event.file_path] = event
                 if progress_callback:
                     elapsed_time = (datetime.now() - start_time).total_seconds()
-                    overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
-                    overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
-                    await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))
+                    overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+                    remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+                    overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+                    status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
+                    await progress_callback(HFRepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))

             await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
-            completed_files += 1
-            file_progress[file_info["path"]] = HFRepoFileProgressEvent(file_info["path"], file_info["size"], file_info["size"], 0, timedelta(0), "complete")
+            progress_state['completed_files'] += 1
+            file_progress[file_info["path"]] = HFRepoFileProgressEvent(file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
             if progress_callback:
                 elapsed_time = (datetime.now() - start_time).total_seconds()
-                overall_speed = downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
-                overall_eta = timedelta(seconds=(total_bytes - downloaded_bytes) / overall_speed) if overall_speed > 0 else timedelta(0)
-                await progress_callback(HFRepoProgressEvent(completed_files, total_files, downloaded_bytes, total_bytes, overall_eta, file_progress))
-
-        tasks = [download_with_progress(file_info) for file_info in filtered_file_list]
+                overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+                remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+                overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+                status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+                await progress_callback(HFRepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+
+        progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
+        tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
         await asyncio.gather(*tasks)

     return snapshot_dir
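
The to_dict/from_dict pair exists because these events are shipped between nodes as JSON and timedelta is not JSON-serializable, hence the round-trip through seconds. A self-contained round-trip sketch using the fields added in this commit (all values hypothetical; assumes both classes are dataclasses, as the file suggests):

    import json
    from datetime import timedelta
    from exo.inference.hf_helpers import HFRepoFileProgressEvent, HFRepoProgressEvent

    file_event = HFRepoFileProgressEvent(
        file_path="model.safetensors", downloaded=1024 ** 3,
        downloaded_this_session=512 * 1024 ** 2, total=4 * 1024 ** 3,
        speed=50 * 1024 ** 2, eta=timedelta(seconds=60), status="in_progress")
    repo_event = HFRepoProgressEvent(
        completed_files=1, total_files=4, downloaded_bytes=1024 ** 3,
        downloaded_bytes_this_session=512 * 1024 ** 2, total_bytes=16 * 1024 ** 3,
        overall_speed=50 * 1024 ** 2, overall_eta=timedelta(seconds=300),
        file_progress={file_event.file_path: file_event}, status="in_progress")

    wire = json.dumps(repo_event.to_dict())   # what gets broadcast (see main.py below)
    restored = HFRepoProgressEvent.from_dict(json.loads(wire))
    assert restored.overall_eta == timedelta(seconds=300)
    assert restored.file_progress["model.safetensors"].speed == 50 * 1024 ** 2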

+ 3 - 4
exo/inference/tinygrad/inference.py

@@ -1,16 +1,14 @@
-from functools import partial
 from pathlib import Path
 from typing import List, Optional, Union, Callable, Coroutine, Any
 import json
-from tiktoken.load import load_tiktoken_bpe
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
 from tinygrad import Tensor, nn, Context, GlobalCounters
-from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
+from tinygrad.helpers import tqdm
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
 import numpy as np
-from exo.inference.hf_helpers import HFRepoProgressCallback, HFRepoProgressEvent, download_all_files, get_repo_root
+from exo.inference.hf_helpers import HFRepoProgressCallback, HFRepoProgressEvent, download_all_files

 MODEL_PARAMS = {
   "8B": {
@@ -88,6 +86,7 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", quantize=
       )
   else:
     weights = load(str(model_path))
+
   if "model.embed_tokens.weight" in weights:
     weights = convert_from_huggingface(
       weights,

+ 7 - 9
exo/orchestration/standard_node.py

@@ -3,6 +3,7 @@ import json
 import asyncio
 import uuid
 import time
+import traceback
 from typing import List, Dict, Optional, Tuple, Union
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
@@ -13,6 +14,7 @@ from exo.topology.partitioning_strategy import Partition, PartitioningStrategy,
 from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
+from exo.inference.hf_helpers import HFRepoProgressEvent


 class StandardNode(Node):
@@ -54,12 +56,14 @@ class StandardNode(Node):
             self.current_topology.active_node_id = None
       download_progress = None
       if status_data.get("type", "") == "download_progress":
-        if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('current')}/{status_data.get('total')} ({round(status_data.get('current') / status_data.get('total') * 100, 2)}%)")
+        if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
         if status_data.get("node_id") == self.id:
-          download_progress = (status_data.get('current'), status_data.get('total'))
+          download_progress = HFRepoProgressEvent.from_dict(status_data.get('progress'))
       if self.topology_viz:
         self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), download_progress)
-    except json.JSONDecodeError:
+    except Exception as e:
+      if DEBUG >= 1: print(f"Error updating visualization: {e}")
+      traceback.print_exc()
       pass

   async def start(self, wait_for_peers: int = 0) -> None:
@@ -231,8 +235,6 @@ class StandardNode(Node):
       return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
-      import traceback
-
       traceback.print_exc()
       return None

@@ -368,8 +370,6 @@ class StandardNode(Node):
         print(f"Timeout broadcasting result to {peer.id()}")
         print(f"Timeout broadcasting result to {peer.id()}")
       except Exception as e:
       except Exception as e:
         print(f"Error broadcasting result to {peer.id()}: {e}")
         print(f"Error broadcasting result to {peer.id()}: {e}")
-        import traceback
-
         traceback.print_exc()
         traceback.print_exc()
 
 
     await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
     await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
@@ -383,8 +383,6 @@ class StandardNode(Node):
         print(f"Timeout sending opaque status to {peer.id()}")
         print(f"Timeout sending opaque status to {peer.id()}")
       except Exception as e:
       except Exception as e:
         print(f"Error sending opaque status to {peer.id()}: {e}")
         print(f"Error sending opaque status to {peer.id()}: {e}")
-        import traceback
-
         traceback.print_exc()
         traceback.print_exc()
 
 
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)

+ 56 - 4
exo/viz/topology_viz.py

@@ -1,6 +1,6 @@
 import math
 from typing import List, Optional, Tuple
-from exo.helpers import exo_text
+from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
 from exo.topology.topology import Topology
 from exo.topology.partitioning_strategy import Partition
 from rich.console import Console
@@ -8,8 +8,10 @@ from rich.panel import Panel
 from rich.text import Text
 from rich.live import Live
 from rich.style import Style
+from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
+from rich.table import Table
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
-
+from exo.inference.hf_helpers import HFRepoProgressEvent

 class TopologyViz:
   def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
@@ -24,7 +26,7 @@ class TopologyViz:
     self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
     self.live_panel.start()

-  def update_visualization(self, topology: Topology, partitions: List[Partition], download_progress: Optional[Tuple[int, int]] = None):
+  def update_visualization(self, topology: Topology, partitions: List[Partition], download_progress: HFRepoProgressEvent = None):
     self.topology = topology
     self.partitions = partitions
     self.download_progress = download_progress
@@ -34,7 +36,7 @@ class TopologyViz:
     self.panel.renderable = self._generate_layout()
     # Update the panel title with the number of nodes and partitions
     node_count = len(self.topology.nodes)
-    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''}){f' {self.download_progress[0]/self.download_progress[1]:.2%} Downloaded' if self.download_progress else ''}"
+    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
     self.live_panel.update(self.panel, refresh=True)

   def _generate_layout(self) -> str:
@@ -47,6 +49,31 @@ class TopologyViz:
     # Generate visualization
     visualization = [[" " for _ in range(100)] for _ in range(55)]  # Decreased height

+    # Draw download first so everything else is drawn on top
+    # If a download is in progress, show the download info summary
+    if self.download_progress and self.download_progress.status != "complete":
+        download_summary = _generate_download_summary(self.download_progress)
+        download_panel = Panel(
+            download_summary,
+            title="Download Progress",
+            border_style="cyan",
+            expand=False,
+            width=96,  # Further reduced to ensure it fits within the visualization
+            height=None  # Allow the panel to adjust its height based on content
+        )
+        console = Console(width=98, height=55)  # Reduced console width
+        with console.capture() as capture:
+            console.print(download_panel)
+        download_lines = capture.get().split('\n')
+        download_start_y = 15
+        panel_width = len(max(download_lines, key=len))
+        start_x = max(1, (100 - panel_width) // 2)  # Ensure start_x is at least 1 to avoid left border cut-off
+        for i, line in enumerate(download_lines):
+            for j, char in enumerate(line):
+                if 1 <= start_x + j < 99 and download_start_y + i < 55:  # Ensure we don't write to the rightmost column
+                    visualization[download_start_y + i][start_x + j] = char
+
+
     # Add exo_text at the top in bright yellow
     exo_lines = exo_text.split("\n")
     yellow_style = Style(color="bright_yellow")
@@ -168,3 +195,28 @@ class TopologyViz:

     # Convert to string
     return "\n".join("".join(str(char) for char in row) for row in visualization)
+
+def _generate_download_summary(download_progress) -> Table:
+    summary = Table(show_header=False, box=None, padding=(0, 1))
+    summary.add_column("Info", style="cyan", no_wrap=True)
+    summary.add_column("Progress", style="cyan", no_wrap=True)
+    summary.add_column("Percentage", style="cyan", no_wrap=True)
+
+    title = f"Downloading model ({download_progress.completed_files}/{download_progress.total_files}):"
+    summary.add_row(Text(title, style="bold"))
+    progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
+    summary.add_row(progress_info)
+
+    eta_info = f"ETA: {download_progress.overall_eta}"
+    summary.add_row(eta_info)
+
+    summary.add_row("")  # Empty row for spacing
+
+    for file_path, file_progress in download_progress.file_progress.items():
+      if file_progress.status != "complete":
+        progress = int(file_progress.downloaded / file_progress.total * 20)  # Increased bar width
+        bar = f"[{'=' * progress}{' ' * (20 - progress)}]"
+        percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
+        summary.add_row(Text(file_path[:20], style="cyan"), bar, percentage)  # Increased file path length
+
+    return summary
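
The summary table can also be rendered on its own with rich, which is handy for checking the layout outside a live cluster. A sketch with a synthetic one-file event (values hypothetical; output shown approximately):

    from datetime import timedelta
    from rich.console import Console
    from exo.inference.hf_helpers import HFRepoFileProgressEvent, HFRepoProgressEvent
    from exo.viz.topology_viz import _generate_download_summary

    file_progress = {
        "model.safetensors": HFRepoFileProgressEvent(
            "model.safetensors", 512 * 1024 ** 2, 512 * 1024 ** 2,
            1024 ** 3, 25 * 1024 ** 2, timedelta(seconds=20), "in_progress"),
    }
    event = HFRepoProgressEvent(0, 1, 512 * 1024 ** 2, 512 * 1024 ** 2, 1024 ** 3,
                                25 * 1024 ** 2, timedelta(seconds=20),
                                file_progress, "in_progress")
    Console().print(_generate_download_summary(event))
    # Downloading model (0/1):
    # 512.00 MB / 1.00 GB (25.00 MB/s)
    # ETA: 0:00:20
    # model.safetensors  [==========          ]  50%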

+ 1 - 1
main.py

@@ -60,7 +60,7 @@ node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)
-inference_engine.set_progress_callback(lambda event: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "current": event.downloaded_bytes, "total": event.total_bytes}))))
+inference_engine.set_progress_callback(lambda event: asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()}))))

 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""