浏览代码

smart model downloading for mlx #16

Alex Cheema 1 年之前
父节点
当前提交
3a230f3b44
共有 1 个文件被更改,包括 24 次插入4 次删除
  1. 24 4
      exo/inference/mlx/sharded_utils.py

+ 24 - 4
exo/inference/mlx/sharded_utils.py

@@ -157,7 +157,22 @@ async def snapshot_download_async(*args, **kwargs):
     func = partial(snapshot_download, *args, **kwargs)
     func = partial(snapshot_download, *args, **kwargs)
     return await asyncio.get_event_loop().run_in_executor(None, func)
     return await asyncio.get_event_loop().run_in_executor(None, func)
 
 
-async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
+model_file_to_layers = {
+    "mlx-community/Meta-Llama-3-70B-Instruct-4bit": {
+        "model-00001-of-00008.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+        "model-00002-of-00008.safetensors": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
+        "model-00003-of-00008.safetensors": [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
+        "model-00004-of-00008.safetensors": [31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42],
+        "model-00005-of-00008.safetensors": [42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53],
+        "model-00006-of-00008.safetensors": [53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64],
+        "model-00007-of-00008.safetensors": [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75],
+        "model-00008-of-00008.safetensors": [75, 76, 77, 78, 79]
+    }
+}
+
+async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: Optional[str] = None) -> Path:
+
+
     """
     """
     Ensures the model is available locally. If the path does not exist locally,
     Ensures the model is available locally. If the path does not exist locally,
     it is downloaded from the Hugging Face Hub.
     it is downloaded from the Hugging Face Hub.
@@ -171,6 +186,12 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -
     """
     """
     model_path = Path(path_or_hf_repo)
     model_path = Path(path_or_hf_repo)
     if not model_path.exists():
     if not model_path.exists():
+        safetensors_allow_patterns = ["*.safetensors"] if not shard or path_or_hf_repo not in model_file_to_layers else [
+            name for name, included_layers in model_file_to_layers[path_or_hf_repo].items()
+            if any(layer in range(shard.start_layer, shard.end_layer + 1) for layer in included_layers)
+        ]
+        print(f"{safetensors_allow_patterns=}")
+
         try:
         try:
             model_path = Path(
             model_path = Path(
                 await snapshot_download_async(
                 await snapshot_download_async(
@@ -178,12 +199,11 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -
                     revision=revision,
                     revision=revision,
                     allow_patterns=[
                     allow_patterns=[
                         "*.json",
                         "*.json",
-                        "*.safetensors",
                         "*.py",
                         "*.py",
                         "tokenizer.model",
                         "tokenizer.model",
                         "*.tiktoken",
                         "*.tiktoken",
                         "*.txt",
                         "*.txt",
-                    ],
+                    ] + safetensors_allow_patterns,
                 )
                 )
             )
             )
         except RepositoryNotFoundError:
         except RepositoryNotFoundError:
@@ -226,7 +246,7 @@ async def load_shard(
         FileNotFoundError: If config file or safetensors are not found.
         FileNotFoundError: If config file or safetensors are not found.
         ValueError: If model class or args class are not found.
         ValueError: If model class or args class are not found.
     """
     """
-    model_path = await get_model_path(path_or_hf_repo)
+    model_path = await get_model_path(path_or_hf_repo, shard=shard)
 
 
     model = load_model_shard(model_path, shard, lazy, model_config)
     model = load_model_shard(model_path, shard, lazy, model_config)
     if adapter_path is not None:
     if adapter_path is not None: