
stick to same convention as new llama

Alex Cheema committed 11 months ago
commit 2fb961fccd
2 changed files with 19 additions and 17 deletions:
  1. exo/inference/mlx/models/llava.py  (+18 -16)
  2. exo/inference/mlx/sharded_utils.py  (+1 -1)

exo/inference/mlx/models/llava.py  (+18 -16)

@@ -9,6 +9,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from mlx_lm.models.base import BaseModelArgs, KVCache
 from exo.inference.shard import Shard
+from .base import IdentityBlock
 import numpy as np
 
 
@@ -369,11 +370,10 @@ class TransformerBlock(nn.Module):
 
 
 class Llama(nn.Module):
-    def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
+    def __init__(self, config: TextConfig, shard: Shard):
         super().__init__()
         self.config = config
-        self.is_first_layer = is_first_layer
-        self.is_last_layer = is_last_layer
+        self.shard = shard
         self.vocab_size = config.vocab_size
         self.model_type = config.model_type
         self.num_hidden_layers = config.num_hidden_layers
@@ -381,10 +381,14 @@ class Llama(nn.Module):
         self.head_dim = config.head_dim
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = [
-            TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
-        ]
-        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layers = []
+        for i in range(self.num_hidden_layers):
+          if self.shard.start_layer <= i <= self.shard.end_layer:
+            self.layers.append(TransformerBlock(config=config))
+          else:
+            self.layers.append(IdentityBlock())
+        if self.shard.is_last_layer():
+            self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def __call__(
             self,
@@ -394,7 +398,7 @@ class Llama(nn.Module):
     ):
         # for passing merged input embeddings
         if inputs_embeds is None:
-            if self.is_first_layer:
+            if self.shard.is_first_layer():
                 h = self.embed_tokens(inputs)
             else:
                 h = inputs
@@ -413,20 +417,20 @@ class Llama(nn.Module):
         for layer, c in zip(self.layers, cache):
             h = layer(h, mask, c)
 
-        if self.is_last_layer:
+        if self.shard.is_last_layer():
             h = self.norm(h)
         return h
 
 class LanguageModel(nn.Module):
-    def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
+    def __init__(self, config: TextConfig, shard: Shard):
         super().__init__()
         self.model_type = config.model_type
         if self.model_type != "llama":
             raise ValueError(
                 f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
             )
-        self.is_last_layer = is_last_layer
-        self.model = Llama(config, is_first_layer, is_last_layer)
+        self.shard = shard
+        self.model = Llama(config, shard)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
     def __call__(
@@ -436,7 +440,7 @@ class LanguageModel(nn.Module):
         inputs_embeds=None,
     ):
         out = self.model(inputs, cache, inputs_embeds)
-        if self.is_last_layer:
+        if self.shard.is_last_layer():
             out = self.lm_head(out)
         return out
 
@@ -485,8 +489,6 @@ class ModelArgs(LlaVAConfig):
         if not self.shard.is_first_layer():
             self.vision_config = None
 
-        self.text_config.num_hidden_layers = self.shard.get_layer_count()
-
 
 class LlavaMultiModalProjector(nn.Module):
     def __init__(self, config: LlaVAConfig):
@@ -516,7 +518,7 @@ class Model(nn.Module):
             self.multi_modal_projector = LlavaMultiModalProjector(config)
             self.vision_feature_layer = config.vision_feature_layer
             self.vision_feature_select_strategy = config.vision_feature_select_strategy
-        self.language_model = LanguageModel(config.text_config, config.shard.is_first_layer(), config.shard.is_last_layer())
+        self.language_model = LanguageModel(config.text_config, config.shard)
 
     def get_input_embeddings(
             self,

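This mirrors the sharded llama model's convention: Llama and LanguageModel take the Shard itself, build a real TransformerBlock only for layer indices inside the shard, and fill the remaining slots with IdentityBlock so layer indices (and therefore weight key names) stay stable across nodes. The sketch below illustrates that convention; the Shard fields and the IdentityBlock body are inferred from how they are used in this diff, not copied from the repo.

import mlx.nn as nn
from dataclasses import dataclass


# Assumed shape of the shard descriptor, inferred from its usage in the diff above.
@dataclass
class Shard:
  model_id: str
  start_layer: int
  end_layer: int
  n_layers: int

  def is_first_layer(self) -> bool:
    return self.start_layer == 0

  def is_last_layer(self) -> bool:
    return self.end_layer == self.n_layers - 1


class IdentityBlock(nn.Module):
  # Pass-through stand-in for layers this node does not own. It holds no
  # parameters, so a shard only ever expects weights for its own layers.
  def __call__(self, x, *args, **kwargs):
    return x


def build_layers(shard: Shard, make_block):
  # Same selection rule as the diff: real blocks inside [start_layer, end_layer],
  # IdentityBlock placeholders everywhere else, keeping layer indices stable.
  return [
    make_block(i) if shard.start_layer <= i <= shard.end_layer else IdentityBlock()
    for i in range(shard.n_layers)
  ]

Because the placeholders carry no parameters, a node holding, say, layers 8-15 of a 32-layer model only exposes parameters under its own layer indices, which is what allows the stricter per-shard weight loading in the second file below.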
exo/inference/mlx/sharded_utils.py  (+1 -1)

@@ -129,7 +129,7 @@ def load_model_shard(
       class_predicate=None,
     )
 
-  model.load_weights(list(weights.items()))
+  model.load_weights(list(weights.items()), strict=True)
 
   if not lazy:
     mx.eval(model.parameters())
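Passing strict=True makes load_weights validate that the provided (name, array) pairs exactly match the module's parameters instead of silently ignoring mismatches. A toy illustration of what the flag enforces (not the exo loading path):

import mlx.nn as nn
from mlx.utils import tree_flatten

# Toy module: its only parameters are "weight" and "bias".
model = nn.Linear(4, 4)
weights = tree_flatten(model.parameters())  # [("weight", ...), ("bias", ...)]

model.load_weights(weights, strict=True)    # keys match the module: loads fine
# model.load_weights(weights[:1], strict=True) would error: "bias" is missing

Combined with the IdentityBlock placeholders above, which contribute no parameters, a shard's model only expects its own layers' weights, so strict loading can still succeed as long as the weights passed in are filtered to the same shard.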