|
@@ -13,10 +13,7 @@ from exo.inference.shard import Shard
|
|
from exo.orchestration import Node
|
|
from exo.orchestration import Node
|
|
|
|
|
|
# Maps a user-facing model alias to, per inference engine, the Shard
# describing that model: its repo/model id and total transformer layer
# count (start_layer/end_layer are placeholders refined at partition time —
# both 0 here; the orchestrator assigns real layer ranges per node).
shard_mappings = {
  # llama
  "llama-3.1-8b": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
  },
  "llama-3.1-405b": {
    # NOTE(review): hardcoded local path — only resolvable on the author's
    # machine; should be replaced with a hub model id before release.
    "MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
  },
  "llama-3-8b": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
  },
  "llama-3-70b": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
  },
  # mistral
  "mistral-nemo": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
  },
  "mistral-large": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
  },
}
|
|
|
|
|
|
class Message:
|
|
class Message:
|