@@ -11,7 +11,7 @@ import aiofiles.os as aios
import aiohttp
import aiofiles
from urllib.parse import urljoin
-from typing import Callable, Union, Tuple, Dict
+from typing import Callable, Union, Tuple, Dict, List
import time
from datetime import timedelta
import asyncio
@@ -82,7 +82,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
url = urljoin(base_url, path)
headers = await get_auth_headers()
async with session.get(url, headers=headers) as r:
- assert r.status == 200, r.status
+ assert r.status == 200, f"Failed to download {path} from {url}: {r.status}"
length = int(r.headers.get('content-length', 0))
n_read = 0
async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
@@ -106,9 +106,13 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]
async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
return index_data.get("weight_map")
-async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> list[str]:
- weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
- return get_allow_patterns(weight_map, shard)
+async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> List[str]:
+ try:
+ weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
+ return get_allow_patterns(weight_map, shard)
+ except Exception as e:
+ if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}: {e}")
+ return ["*"]
async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
if DEBUG >= 6 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
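For reference, a minimal, self-contained sketch of the fallback behaviour the new resolve_allow_patterns hunk introduces: if fetching the weight map fails for any reason, the resolver degrades to the ["*"] allow pattern, i.e. download the whole repo. fetch_weight_map below is a hypothetical stand-in for get_weight_map(get_repo(...)); only the try/except shape mirrors the patch.

import asyncio
from typing import Dict, List, Optional

async def fetch_weight_map(repo_id: str) -> Optional[Dict[str, str]]:
  # Hypothetical stand-in for get_weight_map(get_repo(...)): assume it can raise
  # (network error, missing index file) or return None for repos without an index.
  raise RuntimeError("weight map index not available")

async def resolve_allow_patterns_sketch(repo_id: str) -> List[str]:
  try:
    weight_map = await fetch_weight_map(repo_id)
    # The real code derives shard-specific patterns via get_allow_patterns(weight_map, shard);
    # here we simply return the distinct filenames from the map.
    return sorted(set(weight_map.values())) if weight_map else ["*"]
  except Exception as e:
    print(f"Error getting weight map for {repo_id}: {e}")
    return ["*"]  # fall back to downloading every file in the repo

if __name__ == "__main__":
  print(asyncio.run(resolve_allow_patterns_sketch("some-org/some-model")))  # -> ['*']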