|
@@ -8,16 +8,19 @@ import asyncio
|
|
import aiohttp
|
|
import aiohttp
|
|
from functools import partial
|
|
from functools import partial
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
-from typing import Optional, Tuple
|
|
|
|
-import requests
|
|
|
|
|
|
+from typing import Optional, Tuple, Union, List, Callable
|
|
from PIL import Image
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
from io import BytesIO
|
|
import base64
|
|
import base64
|
|
|
|
+import os
|
|
|
|
|
|
from exo import DEBUG
|
|
from exo import DEBUG
|
|
import mlx.core as mx
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
import mlx.nn as nn
|
|
-from huggingface_hub import snapshot_download
|
|
|
|
|
|
+from huggingface_hub import snapshot_download, list_repo_tree, get_paths_info
|
|
|
|
+from huggingface_hub.utils import filter_repo_objects
|
|
|
|
+from huggingface_hub.file_download import repo_folder_name
|
|
|
|
+from huggingface_hub.constants import HF_HUB_CACHE
|
|
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
|
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
|
from transformers import AutoProcessor
|
|
from transformers import AutoProcessor
|
|
|
|
|
|
@@ -144,12 +147,50 @@ def load_model_shard(
|
|
return model
|
|
return model
|
|
|
|
|
|
|
|
|
|
-async def snapshot_download_async(*args, **kwargs):
|
|
|
|
- func = partial(snapshot_download, *args, **kwargs)
|
|
|
|
- return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
|
|
|
|
|
|
+async def get_repo_size(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
|
|
|
|
+ it = await asyncio.to_thread(list_repo_tree, repo_id, revision=revision, repo_type=repo_type)
|
|
|
|
+ files = list(filter_repo_objects(it, allow_patterns=allow_patterns, key=lambda f: f.path))
|
|
|
|
+ return sum(file.size for file in files if file.size is not None)
|
|
|
|
+
|
|
|
|
+async def monitor_progress(dir, total_size, print_progress=False, on_progress: Callable[[int, int], None] = None):
|
|
|
|
+ while True:
|
|
|
|
+ await asyncio.sleep(0.1)
|
|
|
|
+ current_size = sum(os.path.getsize(os.path.join(root, file))
|
|
|
|
+ for root, _, files in os.walk(dir)
|
|
|
|
+ for file in files)
|
|
|
|
+ progress = min(current_size / total_size * 100, 100)
|
|
|
|
+ if print_progress:
|
|
|
|
+ print(f"\rProgress: {progress:.2f}% ({current_size}/{total_size} bytes)", end="", flush=True)
|
|
|
|
+ if on_progress:
|
|
|
|
+ on_progress(current_size, total_size)
|
|
|
|
+ if progress >= 100:
|
|
|
|
+ if print_progress:
|
|
|
|
+ print("\nDownload complete!")
|
|
|
|
+ break
|
|
|
|
+
|
|
|
|
+async def download_repo(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
|
|
|
|
+ # Use snapshot_download in a separate thread to not block the event loop
|
|
|
|
+ return await asyncio.to_thread(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
|
|
|
|
+
|
|
|
|
+async def download_async_with_progress(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, on_progress: Callable[[int, int], None] = None):
|
|
|
|
+ storage_folder = os.path.join(HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type="model"))
|
|
|
|
+ # os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
|
|
|
|
+ # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
|
|
|
|
+
|
|
|
|
+ print(f"Estimating size of repository: {repo_id}")
|
|
|
|
+ total_size = await get_repo_size(repo_id)
|
|
|
|
+ print(f"Estimated total size: {total_size} bytes")
|
|
|
|
+
|
|
|
|
+ # Create tasks for download and progress checking
|
|
|
|
+ download_task = asyncio.create_task(download_repo(repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type))
|
|
|
|
+ progress_task = asyncio.create_task(monitor_progress(storage_folder, total_size, on_progress=on_progress))
|
|
|
|
+
|
|
|
|
+ # Wait for both tasks to complete
|
|
|
|
+ result = await asyncio.gather(download_task, progress_task)
|
|
|
|
+ return result[0] # Return the result from download_task
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
|
|
"""
|
|
"""
|
|
Ensures the model is available locally. If the path does not exist locally,
|
|
Ensures the model is available locally. If the path does not exist locally,
|
|
it is downloaded from the Hugging Face Hub.
|
|
it is downloaded from the Hugging Face Hub.
|
|
@@ -165,7 +206,7 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -
|
|
if not model_path.exists():
|
|
if not model_path.exists():
|
|
try:
|
|
try:
|
|
model_path = Path(
|
|
model_path = Path(
|
|
- await snapshot_download_async(
|
|
|
|
|
|
+ await download_async_with_progress(
|
|
repo_id=path_or_hf_repo,
|
|
repo_id=path_or_hf_repo,
|
|
revision=revision,
|
|
revision=revision,
|
|
allow_patterns=[
|
|
allow_patterns=[
|
|
@@ -176,6 +217,7 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -
|
|
"*.tiktoken",
|
|
"*.tiktoken",
|
|
"*.txt",
|
|
"*.txt",
|
|
],
|
|
],
|
|
|
|
+ on_progress=on_download_progress,
|
|
)
|
|
)
|
|
)
|
|
)
|
|
except RepositoryNotFoundError:
|
|
except RepositoryNotFoundError:
|
|
@@ -196,6 +238,7 @@ async def load_shard(
|
|
model_config={},
|
|
model_config={},
|
|
adapter_path: Optional[str] = None,
|
|
adapter_path: Optional[str] = None,
|
|
lazy: bool = False,
|
|
lazy: bool = False,
|
|
|
|
+ on_download_progress: Callable[[int, int], None] = None,
|
|
) -> Tuple[nn.Module, TokenizerWrapper]:
|
|
) -> Tuple[nn.Module, TokenizerWrapper]:
|
|
"""
|
|
"""
|
|
Load the model and tokenizer from a given path or a huggingface repository.
|
|
Load the model and tokenizer from a given path or a huggingface repository.
|
|
@@ -218,7 +261,7 @@ async def load_shard(
|
|
FileNotFoundError: If config file or safetensors are not found.
|
|
FileNotFoundError: If config file or safetensors are not found.
|
|
ValueError: If model class or args class are not found.
|
|
ValueError: If model class or args class are not found.
|
|
"""
|
|
"""
|
|
- model_path = await get_model_path(path_or_hf_repo)
|
|
|
|
|
|
+ model_path = await get_model_path(path_or_hf_repo, on_download_progress=on_download_progress)
|
|
|
|
|
|
model = load_model_shard(model_path, shard, lazy, model_config)
|
|
model = load_model_shard(model_path, shard, lazy, model_config)
|
|
if adapter_path is not None:
|
|
if adapter_path is not None:
|