Browse Source

beautiful download

Alex Cheema 4 months ago
parent
commit
2c0d17c336

+ 2 - 3
exo/api/chatgpt_api.py

@@ -276,9 +276,8 @@ class ChatGPTAPI:
     try:
       response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
       await response.prepare(request)
-      downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
-      for (path, d) in downloads:
-        model_data = { d.shard.model_id: { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
+      async for path, s in self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname):
+        model_data = { s.shard.model_id: { "downloaded": s.downloaded_bytes == s.total_bytes, "download_percentage": 100 if s.downloaded_bytes == s.total_bytes else 100 * float(s.downloaded_bytes) / float(s.total_bytes), "total_size": s.total_bytes, "total_downloaded": s.downloaded_bytes } }
         await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
       await response.write(b"data: [DONE]\n\n")
       return response

+ 39 - 18
exo/download/new_shard_download.py

@@ -11,7 +11,7 @@ import aiofiles.os as aios
 import aiohttp
 import aiofiles
 from urllib.parse import urljoin
-from typing import Callable, Union, Tuple, Dict, List, Optional, Literal
+from typing import Callable, Union, Tuple, Dict, List, Optional, Literal, AsyncIterator
 import time
 from datetime import timedelta
 import asyncio
@@ -20,7 +20,7 @@ import traceback
 import shutil
 import tempfile
 import hashlib
-from tenacity import retry, stop_after_attempt, wait_fixed
+from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_not_exception_type
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -70,8 +70,17 @@ async def seed_models(seed_dir: Union[str, Path]):
           print(f"Error seeding model {path} to {dest_path}")
           traceback.print_exc()
 
+async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> List[Dict[str, Union[str, int]]]:
+  cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
+  if await aios.path.exists(cache_file):
+    async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
+  file_list = await fetch_file_list(repo_id, revision)
+  await aios.makedirs(cache_file.parent, exist_ok=True)
+  async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
+  return file_list
+
 @retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
-async def fetch_file_list(repo_id, revision, path=""):
+async def fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
 
@@ -106,14 +115,14 @@ async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
   headers = await get_auth_headers()
   async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
     async with session.head(url, headers=headers) as r:
-      content_length = int(r.headers.get('content-length', 0))
+      content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
       etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
       assert content_length > 0, f"No content length for {url}"
       assert etag is not None, f"No remote hash for {url}"
       if  (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
       return content_length, etag
 
-@retry(stop=stop_after_attempt(30), wait=wait_fixed(1))
+@retry(stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_not_exception_type(FileNotFoundError))
 async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
   if await aios.path.exists(target_dir/path): return target_dir/path
   await aios.makedirs((target_dir/path).parent, exist_ok=True)
@@ -124,9 +133,10 @@ async def download_file(repo_id: str, revision: str, path: str, target_dir: Path
     url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
     headers = await get_auth_headers()
     if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
-    n_read = 0
+    n_read = resume_byte_pos or 0
     async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
       async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
+        if r.status == 404: raise FileNotFoundError(f"File not found: {url}")
         assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
         async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
           while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
@@ -152,7 +162,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog
   return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)
 
 async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
-  target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--")
+  target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
   index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir)
   async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
   return index_data.get("weight_map")
@@ -166,6 +176,12 @@ async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str)
     if DEBUG >= 1: traceback.print_exc()
     return ["*"]
 
+async def get_downloaded_size(path: Path) -> int:
+  partial_path = path.with_suffix(path.suffix + ".partial")
+  if await aios.path.exists(path): return (await aios.stat(path)).st_size
+  if await aios.path.exists(partial_path): return (await aios.stat(partial_path)).st_size
+  return 0
+
 async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
   if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
   repo_id = get_repo(shard.model_id, inference_engine_classname)
@@ -180,7 +196,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
 
   all_start_time = time.time()
-  file_list = await fetch_file_list(repo_id, revision)
+  file_list = await fetch_file_list_with_cache(repo_id, revision)
   filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
   file_progress: Dict[str, RepoFileProgressEvent] = {}
   def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
@@ -192,7 +208,7 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
     on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
     if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
   for file in filtered_file_list:
-    downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
+    downloaded_bytes = await get_downloaded_size(target_dir/file["path"])
     file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
 
   semaphore = asyncio.Semaphore(max_parallel_downloads)
@@ -225,8 +241,9 @@ class SingletonShardDownloader(ShardDownloader):
     finally:
       if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]
 
-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
-    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
+    async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
+      yield path, status
 
 class CachedShardDownloader(ShardDownloader):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -246,8 +263,9 @@ class CachedShardDownloader(ShardDownloader):
     self.cache[(inference_engine_name, shard)] = target_dir
     return target_dir
 
-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
-    return await self.shard_downloader.get_shard_download_status(inference_engine_name)
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
+    async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
+      yield path, status
 
 class NewShardDownloader(ShardDownloader):
   def __init__(self):
@@ -261,9 +279,12 @@ class NewShardDownloader(ShardDownloader):
     target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
     return target_dir
 
-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
     if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
-    downloads = await asyncio.gather(*[download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
-    if DEBUG >= 6: print("Downloaded shards:", downloads)
-    if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
-    return [d for d in downloads if not isinstance(d, Exception)]
+    tasks = [download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])]
+    for task in asyncio.as_completed(tasks):
+      try:
+        path, progress = await task
+        yield (path, progress)
+      except Exception as e:
+        print("Error downloading shard:", e)

+ 4 - 4
exo/download/shard_download.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple, Dict
+from typing import Optional, Tuple, Dict, AsyncIterator
 from pathlib import Path
 from exo.inference.shard import Shard
 from exo.download.download_progress import RepoProgressEvent
@@ -27,7 +27,7 @@ class ShardDownloader(ABC):
     pass
 
   @abstractmethod
-  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
     """Get the download status of shards.
     
     Returns:
@@ -45,5 +45,5 @@ class NoopShardDownloader(ShardDownloader):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return AsyncCallbackSystem()
 
-  async def get_shard_download_status(self, inference_engine_name: str) -> Optional[Dict[str, float]]:
-    return None
+  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
+    if False: yield

+ 3 - 4
exo/download/test_new_shard_download.py

@@ -1,14 +1,13 @@
-from exo.download.new_shard_download import download_shard, NewShardDownloader
+from exo.download.new_shard_download import NewShardDownloader
 from exo.inference.shard import Shard
-from pathlib import Path
 import asyncio
 
 async def test_new_shard_download():
   shard_downloader = NewShardDownloader()
   shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
   await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
-  download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
-  print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})
+  async for path, shard_status in shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine"):
+    print("Shard download status:", path, shard_status)
 
 if __name__ == "__main__":
   asyncio.run(test_new_shard_download())