
fix legacy model loading

Alex Cheema, 1 year ago
commit eafed8e16e
3 changed files with 12 additions and 3 deletions
  1. exo/inference/mlx/models/llama.py (+5 -1)
  2. exo/inference/mlx/sharded_utils.py (+6 -1)
  3. exo/inference/test_inference_engine.py (+1 -1)

exo/inference/mlx/models/llama.py (+5 -1)

@@ -106,7 +106,11 @@ class Model(nn.Module):
           shard_state_dict[key] = value
       elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
         shard_state_dict[key] = value
-      elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
+      elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
         shard_state_dict[key] = value
 
     return shard_state_dict
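
The added branches handle tied word embeddings: when `tie_word_embeddings` is set, the model reuses `model.embed_tokens` as the output projection, so the last shard must keep the embedding weight instead of a (nonexistent) `lm_head` weight. A minimal sketch of the resulting last-shard key filter, with plain booleans as hypothetical stand-ins for the `self.args`/`self.args.shard` objects:

  def keeps_key(key: str, is_last_layer: bool, tie_word_embeddings: bool) -> bool:
    # Tied embeddings: the output projection reuses model.embed_tokens.
    if is_last_layer and tie_word_embeddings and key.startswith('model.embed_tokens'):
      return True
    # Untied embeddings: a separate lm_head weight exists and must be loaded.
    if is_last_layer and not tie_word_embeddings and key.startswith('lm_head'):
      return True
    # The final norm always lives on the last shard.
    if is_last_layer and key.startswith('model.norm'):
      return True
    return False

  assert keeps_key('model.embed_tokens.weight', True, True)
  assert not keeps_key('lm_head.weight', True, True)   # tied checkpoints ship no lm_head
  assert keeps_key('lm_head.weight', True, False)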

exo/inference/mlx/sharded_utils.py (+6 -1)

@@ -144,10 +144,15 @@ def load_model_shard(
     weights = model.sanitize(weights)
 
   if (quantization := config.get("quantization", None)) is not None:
+    # Handle legacy models which may not have everything quantized
+    def class_predicate(p, m):
+        if not hasattr(m, "to_quantized"):
+            return False
+        return f"{p}.scales" in weights
     nn.quantize(
       model,
       **quantization,
-      class_predicate=None,
+      class_predicate=class_predicate,
     )
 
   model.load_weights(list(weights.items()), strict=True)
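
For context, mlx's `nn.quantize` walks the module tree and quantizes every module for which `class_predicate(path, module)` returns True; the previous `class_predicate=None` quantized everything quantizable, which broke legacy checkpoints that only stored some layers quantized. A sketch of the predicate against a toy model, assuming mlx is installed (the `weights` dict here is a hypothetical stand-in for the loaded checkpoint):

  import mlx.nn as nn

  model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
  # Pretend the checkpoint stored only the first layer quantized, i.e. only
  # that layer has a "<path>.scales" tensor alongside its weight.
  weights = {"layers.0.scales": "..."}

  def class_predicate(p, m):
    if not hasattr(m, "to_quantized"):
      return False  # this module type cannot be quantized at all
    return f"{p}.scales" in weights  # quantize only what the checkpoint quantized

  nn.quantize(model, group_size=64, bits=4, class_predicate=class_predicate)

Layers without matching scales stay in full precision, so the subsequent `load_weights(..., strict=True)` no longer trips over missing quantization parameters.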

exo/inference/test_inference_engine.py (+1 -1)

@@ -44,7 +44,7 @@ asyncio.run(
   test_inference_engine(
     MLXDynamicShardInferenceEngine(),
     MLXDynamicShardInferenceEngine(),
-    "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
+    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
   )
 )