|
@@ -2,6 +2,12 @@ from exo.inference.shard import Shard
|
|
|
|
|
|
model_base_shards = {
|
|
|
### llama
|
|
|
+ "llama-3.2-1b": {
|
|
|
+ "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),
|
|
|
+ },
|
|
|
+ "llama-3.2-3b": {
|
|
|
+ "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
|
|
|
+ },
|
|
|
"llama-3.1-8b": {
|
|
|
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
|
|
|
"TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
|