@@ -208,8 +208,7 @@ class VisionModel(nn.Module):
   ) -> mx.array:
     return self.vision_model(x, output_hidden_states)
 
-  @staticmethod
-  def sanitize(weights):
+  def sanitize(self, weights):
     sanitized_weights = {}
     for k, v in weights.items():
       if "position_ids" in k:
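Dropping `@staticmethod` changes the call convention: `sanitize` is now invoked on a `VisionModel` instance, mirroring the shard-aware `LanguageModel.sanitize(self, weights)` below so that `Model.sanitize` can call both uniformly. A minimal sketch of the call-site difference (variable names are illustrative, not from this diff):

```python
# Before: static, callable on the class itself
weights = VisionModel.sanitize(weights)

# After: instance method, called through the composed model,
# consistent with LanguageModel.sanitize(self, weights) below
weights = self.vision_tower.sanitize(weights)
```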
@@ -380,7 +379,8 @@ class Llama(nn.Module):
     self.num_key_value_heads = config.num_key_value_heads
     self.head_dim = config.head_dim
     assert self.vocab_size > 0
-    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+    if self.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
     self.layers = []
     for i in range(self.num_hidden_layers):
       if self.shard.start_layer <= i <= self.shard.end_layer:
@@ -431,7 +431,8 @@ class LanguageModel(nn.Module):
     )
     self.shard = shard
     self.model = Llama(config, shard)
-    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+    if self.shard.is_last_layer():
+      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
   def __call__(
     self,
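These two hunks gate module creation on shard position: only the first shard embeds tokens, and only the last shard owns `lm_head`. The diff relies on a `Shard` object exposing `start_layer`, `end_layer`, `is_first_layer()`, and `is_last_layer()`; a minimal sketch of such a descriptor follows, with field names beyond `start_layer`/`end_layer` (e.g. `n_layers`) being assumptions rather than code from this diff:

```python
from dataclasses import dataclass

# Hypothetical sketch of the Shard descriptor this diff relies on.
@dataclass(frozen=True)
class Shard:
  model_id: str
  start_layer: int  # first decoder layer held by this shard (inclusive)
  end_layer: int    # last decoder layer held by this shard (inclusive)
  n_layers: int     # total decoder layers in the full model

  def is_first_layer(self) -> bool:
    # Only the first shard embeds tokens.
    return self.start_layer == 0

  def is_last_layer(self) -> bool:
    # Only the last shard applies the final norm and lm_head.
    return self.end_layer == self.n_layers - 1
```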
@@ -444,12 +445,24 @@ class LanguageModel(nn.Module):
       out = self.lm_head(out)
     return out
 
-  @staticmethod
-  def sanitize(weights):
-    # Remove unused precomputed rotary freqs
-    return {
-      k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
-    }
+  def sanitize(self, weights):
+    shard_state_dict = {}
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+
+      if key.startswith('language_model.model.layers.'):
+        layer_num = int(key.split('.')[3])
+        if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
+          continue
+      if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'):
+        continue
+      elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')):
+        continue
+
+      shard_state_dict[key] = value
+
+    return shard_state_dict
 
 @dataclass
 class LlaVAConfig(BaseModelArgs):
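To illustrate what the rewritten `LanguageModel.sanitize` keeps and drops, here is a hedged standalone restatement of the same filtering logic, exercised against a few representative weight keys. It reuses the `Shard` sketch above; the example keys and values are hypothetical, not taken from a real checkpoint:

```python
def filter_for_shard(weights: dict, shard) -> dict:
  """Standalone restatement of the sanitize() filtering in the hunk above."""
  kept = {}
  for key in weights:
    if "self_attn.rotary_emb.inv_freq" in key:
      continue  # precomputed rotary freqs are never loaded
    if key.startswith('language_model.model.layers.'):
      layer_num = int(key.split('.')[3])
      if not (shard.start_layer <= layer_num <= shard.end_layer):
        continue  # decoder layer belongs to another shard
    if not shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'):
      continue  # embeddings live on the first shard only
    if not shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')):
      continue  # final norm / head live on the last shard only
    kept[key] = weights[key]
  return kept

# Example: a middle shard holding layers 8-15 of a 32-layer model
shard = Shard(model_id="llava", start_layer=8, end_layer=15, n_layers=32)
demo = {
  'language_model.model.embed_tokens.weight': ...,              # dropped (not first shard)
  'language_model.model.layers.3.self_attn.q_proj.weight': ...,  # dropped (layer 3 < 8)
  'language_model.model.layers.9.mlp.gate_proj.weight': ...,     # kept
  'language_model.lm_head.weight': ...,                          # dropped (not last shard)
}
assert list(filter_for_shard(demo, shard)) == ['language_model.model.layers.9.mlp.gate_proj.weight']
```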
@@ -599,9 +612,10 @@ class Model(nn.Module):
 
   def sanitize(self, weights):
     if self.config.vision_config:
-      weights = self.vision_tower.sanitize(weights)
+      weights = self.vision_tower.sanitize(weights)
+    else:
+      weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
     weights = self.language_model.sanitize(weights)
-
     return weights
 
   @property
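Taken together, gating module creation with `is_first_layer()`/`is_last_layer()` and pruning the matching keys in `sanitize` keeps strict weight loading viable: each shard instantiates exactly the parameters it loads, and a text-only config additionally sheds all `vision_tower`/`multi_modal_projector` weights via the new `else` branch. A hedged sketch of the resulting load path, where the checkpoint path and `model_config` are placeholders rather than code from this diff:

```python
import mlx.core as mx

model = Model(model_config)              # sharded Model instance from this file (placeholder config)
weights = mx.load("model.safetensors")   # dict[str, mx.array] from the checkpoint
weights = model.sanitize(weights)        # drop keys this shard doesn't own
model.load_weights(list(weights.items()), strict=True)
```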