|
|
@@ -92,20 +92,28 @@ async def fetch_file_list(repo_id, revision, path=""):
|
|
|
|
|
|
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
|
|
|
async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
|
|
|
- if (target_dir/path).exists(): return target_dir/path
|
|
|
- await aios.makedirs((target_dir/path).parent, exist_ok=True)
|
|
|
- base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
|
|
|
- url = urljoin(base_url, path)
|
|
|
- headers = await get_auth_headers()
|
|
|
- async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
|
|
|
- async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
|
|
|
- assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
|
|
|
- length = int(r.headers.get('content-length', 0))
|
|
|
- n_read = 0
|
|
|
- async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
|
|
|
- while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
|
|
|
- await aios.rename(temp_file.name, target_dir/path)
|
|
|
- return target_dir/path
|
|
|
+ temp_file_name = None
|
|
|
+ try:
|
|
|
+ if (target_dir/path).exists(): return target_dir/path
|
|
|
+ await aios.makedirs((target_dir/path).parent, exist_ok=True)
|
|
|
+ base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
|
|
|
+ url = urljoin(base_url, path)
|
|
|
+ headers = await get_auth_headers()
|
|
|
+ async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
|
|
|
+ async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
|
|
|
+ assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
|
|
|
+ length = int(r.headers.get('content-length', 0))
|
|
|
+ n_read = 0
|
|
|
+ async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
|
|
|
+ temp_file_name = temp_file.name
|
|
|
+ while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
|
|
|
+ await aios.rename(temp_file.name, target_dir/path)
|
|
|
+ return target_dir/path
|
|
|
+ finally:
|
|
|
+ if temp_file_name: # attempt to delete tmp file if it still exists
|
|
|
+ try: await aios.unlink(temp_file_name)
|
|
|
+ except: pass
|
|
|
+
|
|
|
|
|
|
def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
|
|
|
all_total_bytes = sum([p.total for p in file_progress.values()])
|
|
|
@@ -233,4 +241,3 @@ class NewShardDownloader(ShardDownloader):
|
|
|
if DEBUG >= 6: print("Downloaded shards:", downloads)
|
|
|
if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
|
|
|
return [d for d in downloads if not isinstance(d, Exception)]
|
|
|
-
|