|
@@ -11,7 +11,7 @@ import aiofiles.os as aios
|
|
|
import aiohttp
|
|
|
import aiofiles
|
|
|
from urllib.parse import urljoin
|
|
|
-from typing import Optional, Callable, Union, Tuple, Dict
|
|
|
+from typing import Callable, Union, Tuple, Dict
|
|
|
import time
|
|
|
from datetime import timedelta
|
|
|
import asyncio
|
|
@@ -111,7 +111,7 @@ async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str)
|
|
|
return get_allow_patterns(weight_map, shard)
|
|
|
|
|
|
async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
|
|
|
- if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
|
|
|
+ if DEBUG >= 6 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
|
|
|
repo_id = get_repo(shard.model_id, inference_engine_classname)
|
|
|
revision = "main"
|
|
|
target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
|