@@ -89,19 +89,6 @@ class Model(nn.Module):
         return out
 
     def sanitize(self, weights):
-        for l in range(self.args.num_hidden_layers):
-            prefix = f"model.layers.{l}"
-            for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
-                for k in ["weight", "scales", "biases"]:
-                    if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
-                        to_join = [
-                            weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
-                            for e in range(self.args.n_routed_experts)
-                        ]
-                        weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
-
         shard_state_dict = {}
 
         for key, value in weights.items():
@@ -113,6 +100,21 @@ class Model(nn.Module):
                 shard_state_dict[key] = value
             elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
                 shard_state_dict[key] = value
+
+        for l in range(self.args.num_hidden_layers):
+            prefix = f"model.layers.{l}"
+            for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
+                for k in ["weight", "scales", "biases"]:
+                    if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
+                        to_join = [
+                            shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
+                            for e in range(self.args.n_routed_experts)
+                        ]
+                        shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+
         return shard_state_dict
 
     @property
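For reference, a minimal standalone sketch of what the relocated loop does, under assumed values (expert count, tensor shapes, and a single projection/key are hypothetical; the real loop covers the gate/down/up projections and the `weight`/`scales`/`biases` keys of quantized checkpoints): each layer's per-expert tensors are popped out of the state dict and fused along a new leading experts axis with `mx.stack`, producing the single `switch_mlp` key the fused MoE layer loads from. Because the loop now runs on `shard_state_dict` after the shard filter, only experts whose layers belong to this shard get stacked.

```python
# Sketch of the stacking step above. Hypothetical expert count and
# shapes; only gate_proj.weight is shown for brevity.
import mlx.core as mx

n_routed_experts = 4  # assumed value for illustration
prefix = "model.layers.0"

# Per-expert tensors as they appear in the checkpoint's state dict.
shard_state_dict = {
    f"{prefix}.mlp.experts.{e}.gate_proj.weight": mx.full((8, 16), float(e))
    for e in range(n_routed_experts)
}

# Pop each expert's tensor and stack along a new leading experts axis,
# replacing the per-expert keys with one fused switch_mlp key.
to_join = [
    shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.gate_proj.weight")
    for e in range(n_routed_experts)
]
shard_state_dict[f"{prefix}.mlp.switch_mlp.gate_proj.weight"] = mx.stack(to_join)

print(shard_state_dict[f"{prefix}.mlp.switch_mlp.gate_proj.weight"].shape)
# -> (4, 8, 16)
```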