Sfoglia il codice sorgente

Special case: when a model doesn't have a model index file, use a wildcard for allow_patterns

Alex Cheema 7 mesi fa
parent
commit
277d63d860
1 ha cambiato i file con 9 aggiunte e 5 eliminazioni
  1. 9 5
      exo/download/new_shard_download.py

+ 9 - 5
exo/download/new_shard_download.py

@@ -11,7 +11,7 @@ import aiofiles.os as aios
 import aiohttp
 import aiofiles
 from urllib.parse import urljoin
-from typing import Callable, Union, Tuple, Dict
+from typing import Callable, Union, Tuple, Dict, List
 import time
 from datetime import timedelta
 import asyncio
@@ -82,7 +82,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
   url = urljoin(base_url, path)
   headers = await get_auth_headers()
   async with session.get(url, headers=headers) as r:
-    assert r.status == 200, r.status
+    assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
     length = int(r.headers.get('content-length', 0))
     n_read = 0
     async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
@@ -106,9 +106,13 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]
     async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
     return index_data.get("weight_map")
 
-async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> list[str]:
-  weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
-  return get_allow_patterns(weight_map, shard)
+async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> List[str]:
+  try:
+    weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
+    return get_allow_patterns(weight_map, shard)
+  except Exception as e:
+    if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}: {e}")
+    return ["*"]
 
 async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
   if DEBUG >= 6 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")