|
@@ -1,7 +1,7 @@
|
|
|
import asyncio
|
|
|
import traceback
|
|
|
from pathlib import Path
|
|
|
-from typing import Dict, List, Tuple, Optional
|
|
|
+from typing import Dict, List, Tuple, Optional, Union
|
|
|
from exo.inference.shard import Shard
|
|
|
from exo.download.shard_download import ShardDownloader
|
|
|
from exo.download.download_progress import RepoProgressEvent
|
|
@@ -90,7 +90,7 @@ class HFShardDownloader(ShardDownloader):
|
|
|
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
|
|
return self._on_progress
|
|
|
|
|
|
- async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
|
|
|
+ async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
|
|
|
if not self.current_shard or not self.current_repo_id:
|
|
|
if DEBUG >= 2:
|
|
|
print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
|
|
@@ -144,6 +144,12 @@ class HFShardDownloader(ShardDownloader):
|
|
|
status["overall"] = (downloaded_bytes / total_bytes) * 100
|
|
|
else:
|
|
|
status["overall"] = 0
|
|
|
+
|
|
|
+ # Add total size in bytes
|
|
|
+ status["total_size"] = total_bytes
|
|
|
+ if status["overall"] != 100:
|
|
|
+ status["total_downloaded"] = downloaded_bytes
|
|
|
+
|
|
|
|
|
|
if DEBUG >= 2:
|
|
|
print(f"Download calculation for {self.current_repo_id}:")
|