@@ -1,62 +1,126 @@
from exo.inference.shard import Shard
+from typing import Optional
-model_base_shards = {
+model_cards = {
### llama
"llama-3.2-1b": {
- "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),
- "TinygradDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16),
+ "layers": 16,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
+ "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
+ },
},
"llama-3.2-3b": {
- "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
- "TinygradDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-3B-Instruct", start_layer=0, end_layer=0, n_layers=28),
+ "layers": 28,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
+ "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+ },
},
"llama-3.1-8b": {
- "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
- "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+ "layers": 32,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
+ "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
+ },
},
"llama-3.1-70b": {
- "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
- "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+ "layers": 80,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
+ "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+ },
},
"llama-3.1-70b-bf16": {
- "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED", start_layer=0, end_layer=0, n_layers=80),
- "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+ "layers": 80,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED",
+ "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+ },
},
- "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
- "llama-3.1-405b-8bit": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", start_layer=0, end_layer=0, n_layers=126),},
"llama-3-8b": {
- "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
- "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+ "layers": 32,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+ "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
+ },
},
"llama-3-70b": {
- "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
- "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
+ "layers": 80,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit",
+ "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
+ },
},
+ "llama-3.1-405b": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit", }, },
+ "llama-3.1-405b-8bit": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", }, },
### mistral
- "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
- "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),},
+ "mistral-nemo": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit", }, },
+ "mistral-large": { "layers": 88, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit", }, },
### deepseek
- "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),},
- "deepseek-coder-v2.5": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", start_layer=0, end_layer=0, n_layers=60),},
+ "deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
+ "deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
### llava
- "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),},
+ "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
### qwen
- "qwen-2.5-coder-1.5b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
- "qwen-2.5-coder-3b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=36),},
- "qwen-2.5-coder-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
- "qwen-2.5-coder-14b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48),},
- "qwen-2.5-coder-32b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=64),},
- "qwen-2.5-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
- "qwen-2.5-math-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
- "qwen-2.5-14b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48),},
- "qwen-2.5-72b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),},
- "qwen-2.5-math-72b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),},
+ "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
+ "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
+ "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
+ "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
+ "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
+ "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
+ "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
+ "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
+ "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
+ "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
### nemotron
- "nemotron-70b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", start_layer=0, end_layer=0, n_layers=80),},
- "nemotron-70b-bf16": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", start_layer=0, end_layer=0, n_layers=80),},
+ "nemotron-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", }, },
+ "nemotron-70b-bf16": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", }, },
# gemma
- "gemma2-9b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/gemma-2-9b-it-4bit", start_layer=0, end_layer=0, n_layers=42),},
- "gemma2-27b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/gemma-2-27b-it-4bit", start_layer=0, end_layer=0, n_layers=46),},
+ "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
+ "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
# dummy
- "dummy": {"DummyInferenceEngine": Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8),},
+ "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
}
+
+pretty_name = {
+ "llama-3.2-1b": "Llama 3.2 1B",
+ "llama-3.2-3b": "Llama 3.2 3B",
+ "llama-3.1-8b": "Llama 3.1 8B",
+ "llama-3.1-70b": "Llama 3.1 70B",
+ "llama-3.1-70b-bf16": "Llama 3.1 70B (BF16)",
+ "llama-3.1-405b": "Llama 3.1 405B",
+ "llama-3.1-405b-8bit": "Llama 3.1 405B (8-bit)",
+ "gemma2-9b": "Gemma2 9B",
+ "gemma2-27b": "Gemma2 27B",
+ "nemotron-70b": "Nemotron 70B",
+ "nemotron-70b-bf16": "Nemotron 70B (BF16)",
+ "mistral-nemo": "Mistral Nemo",
+ "mistral-large": "Mistral Large",
+ "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
+ "deepseek-coder-v2.5": "Deepseek Coder V2.5",
+ "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
+ "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
+ "qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
+ "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
+ "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
+ "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
+ "qwen-2.5-7b": "Qwen 2.5 7B",
+ "qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
+ "qwen-2.5-14b": "Qwen 2.5 14B",
+ "qwen-2.5-72b": "Qwen 2.5 72B",
+ "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
+ "llama-3-8b": "Llama 3 8B",
+ "llama-3-70b": "Llama 3 70B",
+}
+
+def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
+ return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
+
+def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
+ repo = get_repo(model_id, inference_engine_classname)
+ n_layers = model_cards.get(model_id, {}).get("layers", 0)
+ if repo is None or n_layers < 1:
+ return None
+  return Shard(model_id=model_id, start_layer=0, end_layer=0, n_layers=n_layers)
+
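
Usage sketch (not part of the patch): a minimal example of how the new helpers are expected to be called, assuming the file above is exo/models.py and that the inference-engine classname strings match the keys in each card's "repo" mapping.

    from exo.models import build_base_shard, get_repo, model_cards

    # Each card pairs a layer count with per-engine Hugging Face repos.
    assert model_cards["llama-3.2-1b"]["layers"] == 16

    # Resolve the repo for a (model, engine) pair. Returns None when the
    # model is unknown or has no repo registered for that engine.
    repo = get_repo("llama-3.2-1b", "MLXDynamicShardInferenceEngine")
    assert repo == "mlx-community/Llama-3.2-1B-Instruct-4bit"

    # Build the initial single-layer shard used to bootstrap partitioning.
    # Returns None for unknown models or cards without layer metadata.
    shard = build_base_shard("llama-3.2-1b", "MLXDynamicShardInferenceEngine")
    assert shard is not None and shard.n_layers == 16

    # Engines without an entry simply yield None rather than raising.
    assert get_repo("mistral-nemo", "TinygradDynamicShardInferenceEngine") is None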