|
@@ -13,6 +13,7 @@ from PIL import Image
|
|
|
from io import BytesIO
|
|
|
import base64
|
|
|
import os
|
|
|
+import concurrent.futures
|
|
|
|
|
|
from exo import DEBUG
|
|
|
import mlx.core as mx
|
|
@@ -120,7 +121,18 @@ def load_model_shard(
|
|
|
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
|
|
|
|
|
weights = {}
|
|
|
- for wf in weight_files:
|
|
|
+ for wf in sorted(weight_files):
|
|
|
+ if DEBUG >= 8:
|
|
|
+ layer_nums = set()
|
|
|
+ for k in mx.load(wf):
|
|
|
+ if k.startswith("model.layers."):
|
|
|
+ layer_num = int(k.split(".")[2])
|
|
|
+ layer_nums.add(layer_num)
|
|
|
+ if k.startswith("language_model.model.layers."):
|
|
|
+ layer_num = int(k.split(".")[3])
|
|
|
+ layer_nums.add(layer_num)
|
|
|
+ print(f"\"{wf.split('/')[-1]}\": {sorted(layer_nums)},")
|
|
|
+
|
|
|
weights.update(mx.load(wf))
|
|
|
|
|
|
model_class, model_args_class = _get_classes(config=config)
|
|
@@ -150,14 +162,15 @@ def load_model_shard(
|
|
|
async def get_repo_size(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
|
|
|
it = await asyncio.to_thread(list_repo_tree, repo_id, revision=revision, repo_type=repo_type)
|
|
|
files = list(filter_repo_objects(it, allow_patterns=allow_patterns, key=lambda f: f.path))
|
|
|
- return sum(file.size for file in files if file.size is not None)
|
|
|
+ return sum(file.size for file in files if hasattr(file, "size") and file.size is not None)
|
|
|
|
|
|
async def monitor_progress(dir, total_size, print_progress=False, on_progress: Callable[[int, int], None] = None):
|
|
|
while True:
|
|
|
+ try:
|
|
|
await asyncio.sleep(0.1)
|
|
|
current_size = sum(os.path.getsize(os.path.join(root, file))
|
|
|
- for root, _, files in os.walk(dir)
|
|
|
- for file in files)
|
|
|
+ for root, _, files in os.walk(dir)
|
|
|
+ for file in files)
|
|
|
progress = min(current_size / total_size * 100, 100)
|
|
|
if print_progress:
|
|
|
print(f"\rProgress: {progress:.2f}% ({current_size}/{total_size} bytes)", end="", flush=True)
|
|
@@ -167,10 +180,15 @@ async def monitor_progress(dir, total_size, print_progress=False, on_progress: C
|
|
|
if print_progress:
|
|
|
print("\nDownload complete!")
|
|
|
break
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error monitoring progress: {e}")
|
|
|
|
|
|
async def download_repo(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
|
|
|
- # Use snapshot_download in a separate thread to not block the event loop
|
|
|
- return await asyncio.to_thread(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
|
|
|
+ with concurrent.futures.ThreadPoolExecutor() as pool:
|
|
|
+ return await asyncio.get_event_loop().run_in_executor(
|
|
|
+ pool,
|
|
|
+ partial(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
|
|
|
+ )
|
|
|
|
|
|
async def download_async_with_progress(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, on_progress: Callable[[int, int], None] = None):
|
|
|
storage_folder = os.path.join(HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type="model"))
|
|
@@ -184,11 +202,113 @@ async def download_async_with_progress(repo_id: str, revision: Optional[str] = N
|
|
|
progress_task = asyncio.create_task(monitor_progress(storage_folder, total_size, on_progress=on_progress))
|
|
|
|
|
|
# Wait for both tasks to complete
|
|
|
- result = await asyncio.gather(download_task, progress_task)
|
|
|
+ result = await asyncio.gather(download_task, progress_task, return_exceptions=True)
|
|
|
return result[0] # Return the result from download_task
|
|
|
|
|
|
+repo_id_safetensors_layers = {
|
|
|
+ "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit": {
|
|
|
+ "model.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
|
|
|
+ },
|
|
|
+ "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit": {
|
|
|
+ "model-00001-of-00008.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
|
|
|
+ "model-00002-of-00008.safetensors": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
|
|
|
+ "model-00003-of-00008.safetensors": [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
|
|
|
+ "model-00004-of-00008.safetensors": [31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42],
|
|
|
+ "model-00005-of-00008.safetensors": [42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53],
|
|
|
+ "model-00006-of-00008.safetensors": [53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64],
|
|
|
+ "model-00007-of-00008.safetensors": [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75],
|
|
|
+ "model-00008-of-00008.safetensors": [75, 76, 77, 78, 79],
|
|
|
+ },
|
|
|
+ "mlx-community/Meta-Llama-3.1-405B-Instruct-4bit": {
|
|
|
+ "model-00001-of-00046.safetensors": [0, 1, 2],
|
|
|
+ "model-00002-of-00046.safetensors": [2, 3, 4, 5],
|
|
|
+ "model-00003-of-00046.safetensors": [5, 6, 7],
|
|
|
+ "model-00004-of-00046.safetensors": [8, 9, 10],
|
|
|
+ "model-00005-of-00046.safetensors": [10, 11, 12, 13],
|
|
|
+ "model-00006-of-00046.safetensors": [13, 14, 15, 16],
|
|
|
+ "model-00007-of-00046.safetensors": [16, 17, 18, 19],
|
|
|
+ "model-00008-of-00046.safetensors": [19, 20, 21],
|
|
|
+ "model-00009-of-00046.safetensors": [22, 23, 24],
|
|
|
+ "model-00010-of-00046.safetensors": [24, 25, 26, 27],
|
|
|
+ "model-00011-of-00046.safetensors": [27, 28, 29, 30],
|
|
|
+ "model-00012-of-00046.safetensors": [30, 31, 32, 33],
|
|
|
+ "model-00013-of-00046.safetensors": [33, 34, 35],
|
|
|
+ "model-00014-of-00046.safetensors": [36, 37, 38],
|
|
|
+ "model-00015-of-00046.safetensors": [38, 39, 40, 41],
|
|
|
+ "model-00016-of-00046.safetensors": [41, 42, 43, 44],
|
|
|
+ "model-00017-of-00046.safetensors": [44, 45, 46, 47],
|
|
|
+ "model-00018-of-00046.safetensors": [47, 48, 49],
|
|
|
+ "model-00019-of-00046.safetensors": [50, 51, 52],
|
|
|
+ "model-00020-of-00046.safetensors": [52, 53, 54, 55],
|
|
|
+ "model-00021-of-00046.safetensors": [55, 56, 57, 58],
|
|
|
+ "model-00022-of-00046.safetensors": [58, 59, 60, 61],
|
|
|
+ "model-00023-of-00046.safetensors": [61, 62, 63],
|
|
|
+ "model-00024-of-00046.safetensors": [64, 65, 66],
|
|
|
+ "model-00025-of-00046.safetensors": [66, 67, 68, 69],
|
|
|
+ "model-00026-of-00046.safetensors": [69, 70, 71, 72],
|
|
|
+ "model-00027-of-00046.safetensors": [72, 73, 74, 75],
|
|
|
+ "model-00028-of-00046.safetensors": [75, 76, 77],
|
|
|
+ "model-00029-of-00046.safetensors": [78, 79, 80],
|
|
|
+ "model-00030-of-00046.safetensors": [80, 81, 82, 83],
|
|
|
+ "model-00031-of-00046.safetensors": [83, 84, 85, 86],
|
|
|
+ "model-00032-of-00046.safetensors": [86, 87, 88, 89],
|
|
|
+ "model-00033-of-00046.safetensors": [89, 90, 91],
|
|
|
+ "model-00034-of-00046.safetensors": [92, 93, 94],
|
|
|
+ "model-00035-of-00046.safetensors": [94, 95, 96, 97],
|
|
|
+ "model-00036-of-00046.safetensors": [97, 98, 99, 100],
|
|
|
+ "model-00037-of-00046.safetensors": [100, 101, 102, 103],
|
|
|
+ "model-00038-of-00046.safetensors": [103, 104, 105],
|
|
|
+ "model-00039-of-00046.safetensors": [106, 107, 108],
|
|
|
+ "model-00040-of-00046.safetensors": [108, 109, 110, 111],
|
|
|
+ "model-00041-of-00046.safetensors": [111, 112, 113, 114],
|
|
|
+ "model-00042-of-00046.safetensors": [114, 115, 116, 117],
|
|
|
+ "model-00043-of-00046.safetensors": [117, 118, 119],
|
|
|
+ "model-00044-of-00046.safetensors": [120, 121, 122],
|
|
|
+ "model-00045-of-00046.safetensors": [122, 123, 124, 125],
|
|
|
+ "model-00046-of-00046.safetensors": [125]
|
|
|
+ },
|
|
|
+ "mlx-community/Mistral-Nemo-Instruct-2407-4bit": {
|
|
|
+ "model-00001-of-00002.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32],
|
|
|
+ "model-00002-of-00002.safetensors": [32, 33, 34, 35, 36, 37, 38, 39],
|
|
|
+ },
|
|
|
+ "mlx-community/Mistral-Large-Instruct-2407-4bit": {
|
|
|
+ "model-00001-of-00014.safetensors": [0, 1, 2, 3, 4, 5, 6],
|
|
|
+ "model-00002-of-00014.safetensors": [6, 7, 8, 9, 10, 11, 12, 13],
|
|
|
+ "model-00003-of-00014.safetensors": [13, 14, 15, 16, 17, 18, 19, 20],
|
|
|
+ "model-00004-of-00014.safetensors": [20, 21, 22, 23, 24, 25, 26],
|
|
|
+ "model-00005-of-00014.safetensors": [27, 28, 29, 30, 31, 32, 33],
|
|
|
+ "model-00006-of-00014.safetensors": [33, 34, 35, 36, 37, 38, 39, 40],
|
|
|
+ "model-00007-of-00014.safetensors": [40, 41, 42, 43, 44, 45, 46, 47],
|
|
|
+ "model-00008-of-00014.safetensors": [47, 48, 49, 50, 51, 52, 53, 54],
|
|
|
+ "model-00009-of-00014.safetensors": [54, 55, 56, 57, 58, 59, 60],
|
|
|
+ "model-00010-of-00014.safetensors": [61, 62, 63, 64, 65, 66, 67],
|
|
|
+ "model-00011-of-00014.safetensors": [67, 68, 69, 70, 71, 72, 73, 74],
|
|
|
+ "model-00012-of-00014.safetensors": [74, 75, 76, 77, 78, 79, 80, 81],
|
|
|
+ "model-00013-of-00014.safetensors": [81, 82, 83, 84, 85, 86, 87],
|
|
|
+ "model-00014-of-00014.safetensors": [87]
|
|
|
+ },
|
|
|
+ "llava-hf/llava-1.5-7b-hf": {
|
|
|
+ "model-00001-of-00003.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
|
|
|
+ "model-00002-of-00003.safetensors": [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22],
|
|
|
+ "model-00003-of-00003.safetensors": [22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+def get_safetensors_allow_patterns(repo_id: str, shard: Optional[Shard] = None):
|
|
|
+ return ["*.safetensors"] # TODO: enable this
|
|
|
+ if not shard:
|
|
|
+ return ["*.safetensors"]
|
|
|
+
|
|
|
+ allow_patterns = []
|
|
|
+ for repo_id, safetensors_layers in repo_id_safetensors_layers.items():
|
|
|
+ if repo_id == shard.model_id:
|
|
|
+ for safetensor, layers in safetensors_layers.items():
|
|
|
+ if any(shard.start_layer <= layer <= shard.end_layer for layer in layers):
|
|
|
+ allow_patterns.append(safetensor)
|
|
|
+
|
|
|
+ return allow_patterns if len(allow_patterns) > 0 else ["*.safetensors"]
|
|
|
|
|
|
-async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
|
|
|
+async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
|
|
|
"""
|
|
|
Ensures the model is available locally. If the path does not exist locally,
|
|
|
it is downloaded from the Hugging Face Hub.
|
|
@@ -209,12 +329,11 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None, o
|
|
|
revision=revision,
|
|
|
allow_patterns=[
|
|
|
"*.json",
|
|
|
- "*.safetensors",
|
|
|
"*.py",
|
|
|
"tokenizer.model",
|
|
|
"*.tiktoken",
|
|
|
"*.txt",
|
|
|
- ],
|
|
|
+ ] + get_safetensors_allow_patterns(path_or_hf_repo, shard),
|
|
|
on_progress=on_download_progress,
|
|
|
)
|
|
|
)
|
|
@@ -259,7 +378,7 @@ async def load_shard(
|
|
|
FileNotFoundError: If config file or safetensors are not found.
|
|
|
ValueError: If model class or args class are not found.
|
|
|
"""
|
|
|
- model_path = await get_model_path(path_or_hf_repo, on_download_progress=on_download_progress)
|
|
|
+ model_path = await get_model_path(path_or_hf_repo, shard, on_download_progress=on_download_progress)
|
|
|
|
|
|
model = load_model_shard(model_path, shard, lazy, model_config)
|
|
|
if adapter_path is not None:
|