
fix llava sanitize

Alex Cheema 9 months ago
parent commit 33cbacf513
1 changed file with 26 additions and 12 deletions

+26 -12  exo/inference/mlx/models/llava.py

@@ -208,8 +208,7 @@ class VisionModel(nn.Module):
     ) -> mx.array:
         return self.vision_model(x, output_hidden_states)
 
-    @staticmethod
-    def sanitize(weights):
+    def sanitize(self, weights):
         sanitized_weights = {}
         for k, v in weights.items():
             if "position_ids" in k:
@@ -380,7 +379,8 @@ class Llama(nn.Module):
         self.num_key_value_heads = config.num_key_value_heads
         self.head_dim = config.head_dim
         assert self.vocab_size > 0
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        if self.shard.is_first_layer():
+            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = []
         for i in range(self.num_hidden_layers):
           if self.shard.start_layer <= i <= self.shard.end_layer:
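
The embed_tokens guard above (and the lm_head guard in the next hunk) rely on the Shard passed into the model. The Shard definition is not part of this diff; below is a minimal sketch of the interface these checks assume, with field and method names inferred from how they are used in the hunks, so treat the details as an assumption:

    from dataclasses import dataclass

    @dataclass
    class Shard:
      model_id: str
      start_layer: int
      end_layer: int
      n_layers: int

      def is_first_layer(self) -> bool:
        # Only the node holding layer 0 keeps the token embedding.
        return self.start_layer == 0

      def is_last_layer(self) -> bool:
        # Only the node holding the final layer keeps norm and lm_head.
        return self.end_layer == self.n_layers - 1
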
@@ -431,7 +431,8 @@ class LanguageModel(nn.Module):
             )
         self.shard = shard
         self.model = Llama(config, shard)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        if self.shard.is_last_layer():
+            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -444,12 +445,24 @@ class LanguageModel(nn.Module):
             out = self.lm_head(out)
         return out
 
-    @staticmethod
-    def sanitize(weights):
-        # Remove unused precomputed rotary freqs
-        return {
-            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
-        }
+    def sanitize(self, weights):
+        shard_state_dict = {}
+        for key, value in weights.items():
+            if "self_attn.rotary_emb.inv_freq" in key:
+                continue
+
+            if key.startswith('language_model.model.layers.'):
+                layer_num = int(key.split('.')[3])
+                if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
+                    continue
+            if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'):
+                continue
+            elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')):
+                continue
+
+            shard_state_dict[key] = value
+
+        return shard_state_dict
 
 @dataclass
 class LlaVAConfig(BaseModelArgs):
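
To make the new LanguageModel.sanitize concrete, here is a rough, standalone illustration of the same filtering for a hypothetical shard covering layers 0-15 of a 32-layer model. The weight names are examples, not taken from a real checkpoint:

    # Stand-in for a shard that is first (holds layer 0) but not last.
    start_layer, end_layer, is_first, is_last = 0, 15, True, False

    sample_keys = [
      "language_model.model.embed_tokens.weight",                      # kept: first shard
      "language_model.model.layers.3.self_attn.q_proj.weight",         # kept: 0 <= 3 <= 15
      "language_model.model.layers.20.mlp.gate_proj.weight",           # dropped: 20 > end_layer
      "language_model.model.norm.weight",                              # dropped: not last shard
      "language_model.lm_head.weight",                                 # dropped: not last shard
      "language_model.model.layers.3.self_attn.rotary_emb.inv_freq",   # always dropped
    ]

    def keep(key: str) -> bool:
      # Mirrors the checks in LanguageModel.sanitize above.
      if "self_attn.rotary_emb.inv_freq" in key:
        return False
      if key.startswith("language_model.model.layers."):
        layer_num = int(key.split(".")[3])
        if layer_num < start_layer or layer_num > end_layer:
          return False
      if not is_first and key.startswith("language_model.model.embed_tokens"):
        return False
      if not is_last and (key.startswith("language_model.model.norm") or key.startswith("language_model.lm_head")):
        return False
      return True

    print([k for k in sample_keys if keep(k)])
    # -> embed_tokens.weight and layers.3...q_proj.weight survive
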
@@ -599,9 +612,10 @@ class Model(nn.Module):
 
     def sanitize(self, weights):
         if self.config.vision_config:
-            weights = self.vision_tower.sanitize(weights)
+          weights = self.vision_tower.sanitize(weights)
+        else:
+          weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
         weights = self.language_model.sanitize(weights)
-
         return weights
 
     @property
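
For context, a minimal sketch of how these sanitize methods are typically used when loading a sharded checkpoint with MLX. The paths and the surrounding loading code are assumptions, not part of this commit:

    import glob
    import mlx.core as mx

    # Gather raw weights from the checkpoint's safetensors files (hypothetical path).
    weights = {}
    for wf in glob.glob("path/to/llava-checkpoint/*.safetensors"):
      weights.update(mx.load(wf))

    model = Model(config)  # config: LlaVAConfig built elsewhere; carries this node's shard (assumption)

    # sanitize() now drops weights that fall outside this node's shard, so
    # load_weights must run with strict=False: missing keys are expected.
    weights = model.sanitize(weights)
    model.load_weights(list(weights.items()), strict=False)
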