Browse Source

Cleaner download progress UI

Alex Cheema 1 year ago
parent
commit
d9f232b313
3 changed files with 20 additions and 10 deletions
  1. +8 -0
      exo/download/download_progress.py
  2. +8 -8
      exo/download/hf/hf_helpers.py
  3. +4 -2
      exo/viz/topology_viz.py

+ 8 - 0
exo/download/download_progress.py

@@ -4,6 +4,8 @@ from datetime import timedelta
 
 @dataclass
 class RepoFileProgressEvent:
+    repo_id: str
+    repo_revision: str
     file_path: str
     downloaded: int
     downloaded_this_session: int
@@ -14,6 +16,8 @@ class RepoFileProgressEvent:
 
     def to_dict(self):
         return {
+            "repo_id": self.repo_id,
+            "repo_revision": self.repo_revision,
             "file_path": self.file_path,
             "downloaded": self.downloaded,
             "downloaded_this_session": self.downloaded_this_session,
@@ -32,6 +36,8 @@ class RepoFileProgressEvent:
 
 @dataclass
 class RepoProgressEvent:
+    repo_id: str
+    repo_revision: str
     completed_files: int
     total_files: int
     downloaded_bytes: int
@@ -44,6 +50,8 @@ class RepoProgressEvent:
 
     def to_dict(self):
         return {
+            "repo_id": self.repo_id,
+            "repo_revision": self.repo_revision,
             "completed_files": self.completed_files,
             "total_files": self.total_files,
             "downloaded_bytes": self.downloaded_bytes,

+ 8 - 8
exo/download/hf/hf_helpers.py

@@ -122,7 +122,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
         if downloaded_size == total_size:
             if DEBUG >= 2: print(f"File already downloaded: {file_path}")
             if progress_callback:
-                await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+                await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
             return
 
         if response.status == 200:
@@ -145,7 +145,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
                 if downloaded_size == total_size:
                     if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
                     if progress_callback:
-                        await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+                        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
                     return
             except ValueError:
                 if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
@@ -156,7 +156,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
         if downloaded_size == total_size:
             print(f"File already downloaded: {file_path}")
             if progress_callback:
-                await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+                await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
             return
 
         DOWNLOAD_CHUNK_SIZE = 32768
@@ -173,7 +173,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
                     eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
                     status = "in_progress" if downloaded_size < total_size else "complete"
                     if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
-                    await progress_callback(RepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
+                    await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
         if DEBUG >= 2: print(f"Downloaded: {file_path}")
 
 async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None) -> Path:
@@ -207,7 +207,7 @@ async def download_repo_files(repo_id: str, revision: str = "main", progress_cal
         filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
         total_files = len(filtered_file_list)
         total_bytes = sum(file["size"] for file in filtered_file_list)
-        file_progress: Dict[str, RepoFileProgressEvent] = {file["path"]: RepoFileProgressEvent(file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
+        file_progress: Dict[str, RepoFileProgressEvent] = {file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
         start_time = datetime.now()
 
         async def download_with_progress(file_info, progress_state):
@@ -221,18 +221,18 @@ async def download_repo_files(repo_id: str, revision: str = "main", progress_cal
                     remaining_bytes = total_bytes - progress_state['downloaded_bytes']
                     overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
                     status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
-                    await progress_callback(RepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+                    await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
 
             await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
             progress_state['completed_files'] += 1
-            file_progress[file_info["path"]] = RepoFileProgressEvent(file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
+            file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
             if progress_callback:
                 elapsed_time = (datetime.now() - start_time).total_seconds()
                 overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
                 remaining_bytes = total_bytes - progress_state['downloaded_bytes']
                 overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
                 status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-                await progress_callback(RepoProgressEvent(progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
+                await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
 
         progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
         tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]

+ 4 - 2
exo/viz/topology_viz.py

@@ -226,9 +226,11 @@ class TopologyViz:
     summary.add_row(Text("Other Nodes Download Progress:", style="bold"))
     for node_id, progress in self.node_download_progress.items():
         if node_id != self.node_id:
-            truncated_id = node_id[:8] + "..." if len(node_id) > 8 else node_id
+            device = self.topology.nodes.get(node_id)
+            partition = next((p for p in self.partitions if p.node_id == node_id), None)
+            partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
             percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
             speed = pretty_print_bytes_per_second(progress.overall_speed)
-            summary.add_row(f"{truncated_id}: {percentage:.1f}% ({speed})")
+            summary.add_row(f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}: {percentage:.1f}% ({speed} ETA: {progress.overall_eta})")
 
     return summary