|
|
@@ -1,6 +1,11 @@
|
|
|
from exo.inference.shard import Shard
|
|
|
|
|
|
model_base_shards = {
|
|
|
+ ### smollm
|
|
|
+ "smollm2-135m": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/SmolLM2-135M-Instruct", start_layer=0, end_layer=0, n_layers=30),},
|
|
|
+ "smollm2-360m": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/SmolLM2-360M-Instruct", start_layer=0, end_layer=0, n_layers=32),},
|
|
|
+ "smollm2-1.7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/SmolLM2-1.7B-Instruct", start_layer=0, end_layer=0, n_layers=24),},
|
|
|
+
|
|
|
### llama
|
|
|
"llama-3.2-1b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),},
|
|
|
"llama-3.2-3b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
|