Browse Source

Update DeepSeek sanitize to shard layers first, before stacking expert weights into switch_mlp

Anchen 1 year ago
parent
commit
a6bb8ddf41
1 changed file with 15 additions and 13 deletions
  1. 15 13
      exo/inference/mlx/models/deepseek_v2.py

+ 15 - 13
exo/inference/mlx/models/deepseek_v2.py

@@ -89,19 +89,6 @@ class Model(nn.Module):
     return out
     return out
 
 
   def sanitize(self, weights):
   def sanitize(self, weights):
-    for l in range(self.args.num_hidden_layers):
-      prefix = f"model.layers.{l}"
-      for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
-        for k in ["weight", "scales", "biases"]:
-          if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
-            to_join = [
-              weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)
-            ]
-            weights[
-              f"{prefix}.mlp.switch_mlp.{
-              m}.{k}"
-            ] = mx.stack(to_join)
-
     shard_state_dict = {}
     shard_state_dict = {}
 
 
     for key, value in weights.items():
     for key, value in weights.items():
@@ -113,6 +100,21 @@ class Model(nn.Module):
         shard_state_dict[key] = value
         shard_state_dict[key] = value
       elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
       elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
         shard_state_dict[key] = value
         shard_state_dict[key] = value
+
+    for l in range(self.args.num_hidden_layers):
+      prefix = f"model.layers.{l}"
+      for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
+        for k in ["weight", "scales", "biases"]:
+          if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
+            to_join = [
+              shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)
+            ]
+            shard_state_dict[
+              f"{prefix}.mlp.switch_mlp.{
+              m}.{k}"
+            ] = mx.stack(to_join)
+
+
     return shard_state_dict
     return shard_state_dict
 
 
   @property
   @property