add support for mistral nemo and mistral large

Alex Cheema, 9 months ago
commit dd8c5d63a9

+ 12 - 4
exo/api/chatgpt_api.py

@@ -13,10 +13,7 @@ from exo.inference.shard import Shard
 from exo.orchestration import Node
 
 shard_mappings = {
-    "llama-3-8b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
-    },
+    # llama
     "llama-3.1-8b": {
         "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
     },
@@ -26,10 +23,21 @@ shard_mappings = {
     "llama-3.1-405b": {
         "MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
     },
+    "llama-3-8b": {
+        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
+    },
     "llama-3-70b": {
         "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
         "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
     },
+    # mistral
+    "mistral-nemo": {
+        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
+    },
+    "mistral-large": {
+        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
+    },
 }
 
 class Message:
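With these mappings in place, either new model can be selected by name through exo's ChatGPT-compatible API. A minimal sketch of such a request follows; the host, port, and endpoint path are assumptions for illustration, not taken from this commit:

import json
from urllib.request import Request, urlopen

# Assumed local endpoint for exo's ChatGPT-compatible API (host, port,
# and path are illustrative, not confirmed by this commit).
url = "http://localhost:8000/v1/chat/completions"
payload = {
    "model": "mistral-nemo",  # must match a key in shard_mappings
    "messages": [{"role": "user", "content": "Hello!"}],
}
req = Request(
    url,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
    print(json.load(resp))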

+ 3 - 9
exo/inference/mlx/sharded_utils.py

@@ -25,8 +25,8 @@ class ModelNotFoundError(Exception):
         super().__init__(self.message)
 
 MODEL_REMAPPING = {
-    "mistral": "llama",  # mistral is compatible with llama
-    "phi-msft": "phixtral",
+    "sharded_mistral": "sharded_llama",  # mistral is compatible with llama
+    "sharded_phi-msft": "sharded_phixtral",
 }
 
 def _get_classes(config: dict):
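The rename keeps the remapping table in step with the sharded_-prefixed model_type values this module resolves. For context, a minimal sketch of how a loader in the style of mlx-lm's _get_classes could consume MODEL_REMAPPING; the module path and error handling are assumptions, not confirmed by this diff:

import importlib

def _get_classes_sketch(config: dict):
    # Apply the remapping first: with the renamed keys, a model_type that
    # already carries the prefix (e.g. "sharded_mistral") resolves to the
    # sharded_llama implementation, since the architectures are compatible.
    model_type = config["model_type"]
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    try:
        # Assumed module layout; the real import path may differ.
        arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
    except ImportError:
        raise ModelNotFoundError(f"Model type {model_type} not supported.")
    return arch.Model, arch.ModelArgs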
@@ -122,16 +122,10 @@ def load_model_shard(
         weights = model.sanitize(weights)
 
     if (quantization := config.get("quantization", None)) is not None:
-        # Handle legacy models which may not have everything quantized
-        def class_predicate(p, m):
-            if not hasattr(m, "to_quantized"):
-                return False
-            return f"{p}.scales" in all_weights_keys
-
         nn.quantize(
             model,
             **quantization,
-            class_predicate=class_predicate,
+            class_predicate=None,
         )
 
     filtered_weights = {}
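The deleted class_predicate existed to skip modules whose quantization scales were missing from legacy, partially-quantized checkpoints; passing class_predicate=None instead lets mlx's nn.quantize quantize every compatible layer. A minimal before/after sketch, where model, quantization, and all_weights_keys stand in for the locals of load_model_shard:

import mlx.nn as nn

# Before: only quantize modules whose scales actually exist in the checkpoint,
# so legacy models mixing quantized and float layers still load.
def legacy_class_predicate(p, m):
    return hasattr(m, "to_quantized") and f"{p}.scales" in all_weights_keys

# After this commit: quantize every module that supports it.
nn.quantize(model, **quantization, class_predicate=None)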

+ 2 - 0
tinychat/examples/tinychat/index.html

@@ -62,6 +62,8 @@
         <option value="llama-3.1-405b">Llama 3.1 405B</option>
         <option value="llama-3-8b">Llama 3 8B</option>
         <option value="llama-3-70b">Llama 3 70B</option>
+        <option value="mistral-nemo">Mistral Nemo</option>
+        <option value="mistral-large">Mistral Large</option>
       </select>
     </div>
     <div class="home centered" x-show="home === 0" x-transition x-effect="
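Note: these option values must match the shard_mappings keys in exo/api/chatgpt_api.py exactly; the selected value is presumably forwarded as the model field of the chat-completions request, so a mismatch here would surface as an unknown-model error.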