Browse Source

memory-efficient shard loading

Alex Cheema 1 year ago
parent
commit
b9c323bb07
1 changed file with 5 additions and 2 deletions
  1. 5 2
      exo/inference/mlx/sharded_utils.py

+ 5 - 2
exo/inference/mlx/sharded_utils.py

@@ -107,8 +107,11 @@ def load_model_shard(
         raise FileNotFoundError(f"No safetensors found in {model_path}")
 
     weights = {}
+    all_weights_keys = set()
     for wf in weight_files:
-        weights.update(mx.load(wf))
+        weights_dict = mx.load(wf)
+        all_weights_keys.update(weights_dict.keys())
+        weights.update({k: v for k, v in weights_dict.items() if not k.startswith("model.layers.") or shard.start_layer <= int(k.split('.')[2]) <= shard.end_layer})
 
     model_class, model_args_class = _get_classes(config=config)
 
@@ -123,7 +126,7 @@ def load_model_shard(
         def class_predicate(p, m):
             if not hasattr(m, "to_quantized"):
                 return False
-            return f"{p}.scales" in weights
+            return f"{p}.scales" in all_weights_keys
 
         nn.quantize(
             model,