@@ -107,8 +107,11 @@ def load_model_shard(
     raise FileNotFoundError(f"No safetensors found in {model_path}")
 
   weights = {}
+  all_weights_keys = set()
   for wf in weight_files:
-    weights.update(mx.load(wf))
+    weights_dict = mx.load(wf)
+    all_weights_keys.update(weights_dict.keys())
+    weights.update({k: v for k, v in weights_dict.items() if not k.startswith("model.layers.") or shard.start_layer <= int(k.split('.')[2]) <= shard.end_layer})
 
   model_class, model_args_class = _get_classes(config=config)
 
@@ -123,7 +126,7 @@ def load_model_shard(
   def class_predicate(p, m):
     if not hasattr(m, "to_quantized"):
       return False
-    return f"{p}.scales" in weights
+    return f"{p}.scales" in all_weights_keys
 
   nn.quantize(
     model,
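
For reference, a minimal standalone sketch of the key-filtering rule added in the first hunk: non-layer weights (embeddings, norms, etc.) are always kept, while `model.layers.<idx>.*` weights are kept only when `<idx>` falls inside the shard's layer range. The `Shard` dataclass and `keep_key` helper below are stand-ins for illustration, assuming only the `start_layer`/`end_layer` fields used by the diff.

```python
# Sketch of the per-shard weight-key filter (illustrative names, not the real Shard class).
from dataclasses import dataclass


@dataclass
class Shard:
  start_layer: int
  end_layer: int


def keep_key(k: str, shard: Shard) -> bool:
  # Non-layer weights are always kept.
  if not k.startswith("model.layers."):
    return True
  # Layer weights are keyed "model.layers.<idx>....", so the index is field 2.
  layer_idx = int(k.split(".")[2])
  return shard.start_layer <= layer_idx <= shard.end_layer


shard = Shard(start_layer=0, end_layer=15)
print(keep_key("model.embed_tokens.weight", shard))               # True
print(keep_key("model.layers.3.self_attn.q_proj.weight", shard))  # True
print(keep_key("model.layers.20.mlp.up_proj.weight", shard))      # False
```

Note that `class_predicate` in the second hunk consults `all_weights_keys` (the unfiltered key set collected across all safetensors files) rather than the filtered `weights` dict, so the `*.scales` lookup sees every key in the checkpoint, not just the keys kept for this shard.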