Browse Source

resumable downloads with integrity checks

Alex Cheema 9 months ago
parent
commit
7034ee0fcb
1 changed files with 48 additions and 22 deletions
  1. 48 22
      exo/download/new_shard_download.py

+ 48 - 22
exo/download/new_shard_download.py

@@ -11,7 +11,7 @@ import aiofiles.os as aios
 import aiohttp
 import aiofiles
 from urllib.parse import urljoin
-from typing import Callable, Union, Tuple, Dict, List
+from typing import Callable, Union, Tuple, Dict, List, Optional, Literal
 import time
 from datetime import timedelta
 import asyncio
@@ -19,7 +19,8 @@ import json
 import traceback
 import shutil
 import tempfile
-from tenacity import retry, stop_after_attempt, wait_exponential
+import hashlib
+from tenacity import retry, stop_after_attempt, wait_fixed
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -69,7 +70,7 @@ async def seed_models(seed_dir: Union[str, Path]):
           print(f"Error seeding model {path} to {dest_path}")
           traceback.print_exc()
 
-@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
+@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
 async def fetch_file_list(repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
@@ -90,29 +91,54 @@ async def fetch_file_list(repo_id, revision, path=""):
       else:
         raise Exception(f"Failed to fetch file list: {response.status}")
 
-@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
+async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str:
+  hash = hashlib.sha1() if type == "sha1" else hashlib.sha256()
+  if type == "sha1":
+    header = f"blob {(await aios.stat(path)).st_size}\0".encode()
+    hash.update(header)
+  async with aiofiles.open(path, 'rb') as f:
+    while chunk := await f.read(1024 * 1024):
+      hash.update(chunk)
+  return hash.hexdigest()
+
+async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
+  url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
+  headers = await get_auth_headers()
+  async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
+    async with session.head(url, headers=headers) as r:
+      content_length = int(r.headers.get('content-length', 0))
+      etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
+      assert content_length > 0, f"No content length for {url}"
+      assert etag is not None, f"No remote hash for {url}"
+      if  (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
+      return content_length, etag
+
+@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
 async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
-  temp_file_name = None
-  try:
-    if (target_dir/path).exists(): return target_dir/path
-    await aios.makedirs((target_dir/path).parent, exist_ok=True)
-    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-    url = urljoin(base_url, path)
+  if await aios.path.exists(target_dir/path): return target_dir/path
+  await aios.makedirs((target_dir/path).parent, exist_ok=True)
+  length, remote_hash = await file_meta(repo_id, revision, path)
+  partial_path = target_dir/f"{path}.partial"
+  resume_byte_pos = (await aios.stat(partial_path)).st_size if (await aios.path.exists(partial_path)) else None
+  if resume_byte_pos != length:
+    url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
     headers = await get_auth_headers()
+    if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
+    n_read = 0
     async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
       async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
-        assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
-        length = int(r.headers.get('content-length', 0))
-        n_read = 0
-        async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
-          temp_file_name = temp_file.name
-          while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
-          await aios.rename(temp_file.name, target_dir/path)
-        return target_dir/path
-  finally:
-    if temp_file_name: # attempt to delete tmp file if it still exists
-      try: await aios.unlink(temp_file_name)
-      except: pass
+        assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
+        async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
+          while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
+
+  final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
+  integrity = final_hash == remote_hash
+  if not integrity:
+    try: await aios.remove(partial_path)
+    except Exception as e: print(f"Error removing partial file {partial_path}: {e}")
+    raise Exception(f"Downloaded file {target_dir/path} has hash {final_hash} but remote hash is {remote_hash}")
+  await aios.rename(partial_path, target_dir/path)
+  return target_dir/path
 
 
 def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent: