|
@@ -103,19 +103,69 @@ async def fetch_file_list(session, repo_id, revision, path=""):
|
|
|
class HFRepoFileProgressEvent:
    """Progress snapshot for a single file of a repo download.

    NOTE(review): instances are built positionally by callers and via
    ``cls(**data)``, so this class is presumably decorated with
    ``@dataclass`` on the line directly above this block -- confirm.
    """
    # Path of the file this snapshot describes.
    file_path: str
    # Bytes available locally so far (starts from the size already on disk).
    downloaded: int
    # Bytes received during the current run only.
    downloaded_this_session: int
    # Expected size of the file in bytes.
    total: int
    # Transfer rate in bytes per second, truncated to an integer.
    speed: int
    # Estimated time remaining.
    eta: timedelta
    status: Literal["not_started", "in_progress", "complete"]

    def to_dict(self):
        """Return a plain-dict form of the event.

        ``eta`` is emitted as a float number of seconds so that the result
        contains only basic types; every other field is copied as-is.
        """
        return {
            "file_path": self.file_path,
            "downloaded": self.downloaded,
            "downloaded_this_session": self.downloaded_this_session,
            "total": self.total,
            "speed": self.speed,
            "eta": self.eta.total_seconds(),
            "status": self.status,
        }

    @classmethod
    def from_dict(cls, data):
        """Build an event from a mapping shaped like the output of ``to_dict``.

        The mapping is copied first, so the caller's object is left untouched
        and can safely be parsed again or forwarded elsewhere.
        """
        fields = dict(data)
        # Convert eta from seconds back to timedelta
        if 'eta' in fields:
            fields['eta'] = timedelta(seconds=fields['eta'])
        return cls(**fields)
|
|
|
+
|
|
|
@dataclass
class HFRepoProgressEvent:
    """Aggregate progress snapshot for a whole repo download."""
    # Number of files finished so far / number of files selected.
    completed_files: int
    total_files: int
    # Bytes accounted for across all files so far.
    downloaded_bytes: int
    # Bytes received during the current run only.
    downloaded_bytes_this_session: int
    # Sum of the sizes of all selected files.
    total_bytes: int
    # Aggregate transfer rate in bytes per second, truncated to an integer.
    overall_speed: int
    # Estimated time remaining for the whole download.
    overall_eta: timedelta
    # Latest per-file snapshot, keyed by file path. Written as a string
    # annotation so the class body does not depend on definition order.
    file_progress: Dict[str, "HFRepoFileProgressEvent"]
    status: Literal["not_started", "in_progress", "complete"]

    def to_dict(self):
        """Return a plain-dict form of the event.

        ``overall_eta`` is emitted as float seconds and each per-file event
        is converted through its own ``to_dict``.
        """
        return {
            "completed_files": self.completed_files,
            "total_files": self.total_files,
            "downloaded_bytes": self.downloaded_bytes,
            "downloaded_bytes_this_session": self.downloaded_bytes_this_session,
            "total_bytes": self.total_bytes,
            "overall_speed": self.overall_speed,
            "overall_eta": self.overall_eta.total_seconds(),
            "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
            "status": self.status,
        }

    @classmethod
    def from_dict(cls, data):
        """Build an event from a mapping shaped like the output of ``to_dict``.

        Neither ``data`` nor the nested per-file mappings are modified: the
        outer mapping is copied before conversion and every per-file mapping
        is copied before it is handed to ``HFRepoFileProgressEvent.from_dict``.
        """
        fields = dict(data)

        # Convert overall_eta from seconds back to timedelta
        if 'overall_eta' in fields:
            fields['overall_eta'] = timedelta(seconds=fields['overall_eta'])

        # Parse file_progress
        if 'file_progress' in fields:
            fields['file_progress'] = {
                k: HFRepoFileProgressEvent.from_dict(dict(v))
                for k, v in fields['file_progress'].items()
            }

        return cls(**fields)
|
|
|
|
|
|
# Per-file progress hook: called with one HFRepoFileProgressEvent and returns
# a coroutine that the downloader awaits.
HFRepoFileProgressCallback = Callable[[HFRepoFileProgressEvent], Coroutine[Any, Any, None]]
|
|
|
# Repo-wide progress hook: called with one HFRepoProgressEvent and returns a
# coroutine that the downloader awaits.
HFRepoProgressCallback = Callable[[HFRepoProgressEvent], Coroutine[Any, Any, None]]
|
|
@@ -143,11 +193,12 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
|
|
|
async with session.get(url, headers=headers) as response:
|
|
|
total_size = int(response.headers.get('Content-Length', 0))
|
|
|
downloaded_size = local_file_size
|
|
|
+ downloaded_this_session = 0
|
|
|
mode = 'ab' if use_range_request else 'wb'
|
|
|
if downloaded_size == total_size:
|
|
|
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
|
|
|
if progress_callback:
|
|
|
- await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
|
|
|
+ await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
|
|
return
|
|
|
|
|
|
if response.status == 200:
|
|
@@ -170,7 +221,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
|
|
|
if downloaded_size == total_size:
|
|
|
if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
|
|
|
if progress_callback:
|
|
|
- await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
|
|
|
+ await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
|
|
return
|
|
|
except ValueError:
|
|
|
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
|
|
@@ -181,7 +232,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
|
|
|
if downloaded_size == total_size:
|
|
|
print(f"File already downloaded: {file_path}")
|
|
|
if progress_callback:
|
|
|
- await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, 0, timedelta(0), "complete"))
|
|
|
+ await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
|
|
return
|
|
|
|
|
|
DOWNLOAD_CHUNK_SIZE = 32768
|
|
@@ -190,13 +241,15 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
|
|
|
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
|
|
|
f.write(chunk)
|
|
|
downloaded_size += len(chunk)
|
|
|
+ downloaded_this_session += len(chunk)
|
|
|
if progress_callback and total_size:
|
|
|
elapsed_time = (datetime.now() - start_time).total_seconds()
|
|
|
- speed = downloaded_size / elapsed_time if elapsed_time > 0 else 0
|
|
|
+ speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
|
|
|
remaining_size = total_size - downloaded_size
|
|
|
eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
|
|
|
status = "in_progress" if downloaded_size < total_size else "complete"
|
|
|
- await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, total_size, speed, eta, status))
|
|
|
+ if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
|
|
|
+ await progress_callback(HFRepoFileProgressEvent(file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
|
|
|
if DEBUG >= 2: print(f"Downloaded: {file_path}")
|
|
|
|
|
|
async def download_all_files(repo_id: str, revision: str = "main", progress_callback: Optional[HFRepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None):
|
|
@@ -229,35 +282,36 @@ async def download_all_files(repo_id: str, revision: str = "main", progress_call
|
|
|
file_list = await fetch_file_list(session, repo_id, revision)
|
|
|
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
|
|
|
total_files = len(filtered_file_list)
|
|
|
- completed_files = 0
|
|
|
total_bytes = sum(file["size"] for file in filtered_file_list)
|
|
|
- downloaded_bytes = 0
|
|
|
- file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
|
|
|
+ file_progress: Dict[str, HFRepoFileProgressEvent] = {file["path"]: HFRepoFileProgressEvent(file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
|
|
|
start_time = datetime.now()
|
|
|
|
|
|
async def download_with_progress(file_info, progress_state):
    """Download one file and relay repo-wide progress to ``progress_callback``.

    ``progress_state`` is the dict of running totals (``completed_files``,
    ``downloaded_bytes``, ``downloaded_bytes_this_session``) that all download
    tasks update in place.
    """

    async def publish_overall(finished):
        # Turn the shared counters into a single repo-level event and pass it
        # to the caller's hook; with no hook registered there is nothing to do.
        if not progress_callback:
            return
        seconds_elapsed = (datetime.now() - start_time).total_seconds()
        session_bytes = progress_state['downloaded_bytes_this_session']
        overall_speed = int(session_bytes / seconds_elapsed) if seconds_elapsed > 0 else 0
        bytes_left = total_bytes - progress_state['downloaded_bytes']
        if overall_speed > 0:
            overall_eta = timedelta(seconds=bytes_left / overall_speed)
        else:
            overall_eta = timedelta(seconds=0)
        await progress_callback(HFRepoProgressEvent(
            progress_state['completed_files'],
            total_files,
            progress_state['downloaded_bytes'],
            session_bytes,
            total_bytes,
            overall_speed,
            overall_eta,
            file_progress,
            "complete" if finished else "in_progress",
        ))

    async def file_progress_callback(event: HFRepoFileProgressEvent):
        # Add only the delta since the previous snapshot of this file, then
        # remember the new snapshot for the next delta.
        previous = file_progress[event.file_path]
        progress_state['downloaded_bytes'] += event.downloaded - previous.downloaded
        progress_state['downloaded_bytes_this_session'] += (
            event.downloaded_this_session - previous.downloaded_this_session
        )
        file_progress[event.file_path] = event
        # While files are still streaming, completion is judged by byte count.
        await publish_overall(progress_state['downloaded_bytes'] >= total_bytes)

    path = file_info["path"]
    await download_file(session, repo_id, revision, path, snapshot_dir, file_progress_callback)

    progress_state['completed_files'] += 1
    size = file_info["size"]
    # Pin the finished file at its full size while keeping its session byte count.
    file_progress[path] = HFRepoFileProgressEvent(
        path,
        size,
        file_progress[path].downloaded_this_session,
        size,
        0,
        timedelta(0),
        "complete",
    )
    # Once a file has finished, completion is judged by file count.
    await publish_overall(progress_state['completed_files'] >= total_files)
|
|
|
+
|
|
|
+ progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
|
|
|
+ tasks = [download_with_progress(file_info, progress_state) for file_info in filtered_file_list]
|
|
|
await asyncio.gather(*tasks)
|
|
|
|
|
|
return snapshot_dir
|