Prechádzať zdrojové kódy

ensure exo dir on start, retry with exp backoff on file downloads

Alex Cheema 3 mesiacov pred
rodič
commit
6b1c8635fc
2 zmenil súbory, kde vykonal 62 pridanie a 51 odobranie
  1. 53 44
      exo/download/new_shard_download.py
  2. 9 7
      exo/main.py

+ 53 - 44
exo/download/new_shard_download.py

@@ -19,6 +19,7 @@ import json
 import traceback
 import shutil
 import tempfile
+from tenacity import retry, stop_after_attempt, wait_exponential
 
 def exo_home() -> Path:
   return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
@@ -26,6 +27,10 @@ def exo_home() -> Path:
 def exo_tmp() -> Path:
   return Path(tempfile.gettempdir())/"exo"
 
+async def ensure_exo_home() -> Path:
+  await aios.makedirs(exo_home(), exist_ok=True)
+  return exo_home()
+
 async def ensure_exo_tmp() -> Path:
   await aios.makedirs(exo_tmp(), exist_ok=True)
   return exo_tmp()
@@ -64,39 +69,42 @@ async def seed_models(seed_dir: Union[str, Path]):
           print(f"Error seeding model {path} to {dest_path}")
           traceback.print_exc()
 
-async def fetch_file_list(session, repo_id, revision, path=""):
+async def fetch_file_list(repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
 
   headers = await get_auth_headers()
-  async with session.get(url, headers=headers) as response:
-    if response.status == 200:
-      data = await response.json()
-      files = []
-      for item in data:
-        if item["type"] == "file":
-          files.append({"path": item["path"], "size": item["size"]})
-        elif item["type"] == "directory":
-          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
-          files.extend(subfiles)
-      return files
-    else:
-      raise Exception(f"Failed to fetch file list: {response.status}")
+  async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30, connect=10, sock_read=1800, sock_connect=60)) as session:
+    async with session.get(url, headers=headers) as response:
+      if response.status == 200:
+        data = await response.json()
+        files = []
+        for item in data:
+          if item["type"] == "file":
+            files.append({"path": item["path"], "size": item["size"]})
+          elif item["type"] == "directory":
+            subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
+            files.extend(subfiles)
+        return files
+      else:
+        raise Exception(f"Failed to fetch file list: {response.status}")
 
-async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
+@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=0.5))
+async def download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
   if (target_dir/path).exists(): return target_dir/path
   await aios.makedirs((target_dir/path).parent, exist_ok=True)
   base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
   url = urljoin(base_url, path)
   headers = await get_auth_headers()
-  async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
-    assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
-    length = int(r.headers.get('content-length', 0))
-    n_read = 0
-    async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
-      while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
-      await aios.rename(temp_file.name, target_dir/path)
-    return target_dir/path
+  async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
+    async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
+      assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
+      length = int(r.headers.get('content-length', 0))
+      n_read = 0
+      async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
+        while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length)
+        await aios.rename(temp_file.name, target_dir/path)
+      return target_dir/path
 
 def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
   all_total_bytes = sum([p.total for p in file_progress.values()])
@@ -110,10 +118,9 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog
 
 async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
   target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--")
-  async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=10, sock_read=1800, sock_connect=10)) as session:
-    index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
-    async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
-    return index_data.get("weight_map")
+  index_file = await download_file(repo_id, revision, "model.safetensors.index.json", target_dir)
+  async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
+  return index_data.get("weight_map")
 
 async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> List[str]:
   try:
@@ -138,30 +145,32 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
   if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
 
   all_start_time = time.time()
-  async with aiohttp.ClientSession() as session:
-    file_list = await fetch_file_list(session, repo_id, revision)
-    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
-    file_progress: Dict[str, RepoFileProgressEvent] = {}
-    def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
-      start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
-      downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
-      speed = downloaded_this_session / (time.time() - start_time)
-      eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
-      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time)
-      on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
-      if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
-    for file in filtered_file_list:
-      downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
-      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
+  file_list = await fetch_file_list(repo_id, revision)
+  filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
+  file_progress: Dict[str, RepoFileProgressEvent] = {}
+  def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
+    start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
+    downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
+    speed = downloaded_this_session / (time.time() - start_time)
+    eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
+    file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time)
+    on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
+    if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
+  for file in filtered_file_list:
+    downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
+    file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
 
     semaphore = asyncio.Semaphore(max_parallel_downloads)
     async def download_with_semaphore(file):
       async with semaphore:
-        await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
+        await download_file(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
     if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
     final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
     on_progress.trigger_all(shard, final_repo_progress)
-    return target_dir, final_repo_progress
+    if gguf := next((f for f in filtered_file_list if f["path"].endswith(".gguf")), None):
+      return target_dir/gguf["path"], final_repo_progress
+    else:
+      return target_dir, final_repo_progress
 
 def new_shard_downloader() -> ShardDownloader:
   return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))

+ 9 - 7
exo/main.py

@@ -21,7 +21,7 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, NoopShardDownloader
 from exo.download.download_progress import RepoProgressEvent
-from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, exo_home, seed_models
+from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, ensure_exo_home, seed_models
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine
@@ -306,12 +306,8 @@ async def train_model_cli(node: Node, model_name, dataloader, batch_size, iters,
       await hold_outstanding(node)
   await hold_outstanding(node)
 
-
-async def main():
-  loop = asyncio.get_running_loop()
-
-  # Check exo directory permissions
-  home, has_read, has_write = exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access()
+async def check_exo_home():
+  home, has_read, has_write = await ensure_exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access()
   if DEBUG >= 1: print(f"exo home directory: {home}")
   print(f"{has_read=}, {has_write=}")
   if not has_read or not has_write:
@@ -322,6 +318,12 @@ async def main():
           {"❌ No write access" if not has_write else ""}
           """)
 
+async def main():
+  loop = asyncio.get_running_loop()
+
+  try: await check_exo_home()
+  except Exception as e: print(f"Error checking exo home directory: {e}")
+
   if not args.models_seed_dir is None:
     try:
       models_seed_dir = clean_path(args.models_seed_dir)