@@ -9,6 +9,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from mlx_lm.models.base import BaseModelArgs, KVCache
 from exo.inference.shard import Shard
+from .base import IdentityBlock
 import numpy as np
 
 
@@ -369,11 +370,10 @@ class TransformerBlock(nn.Module):
 
 
 class Llama(nn.Module):
-  def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
+  def __init__(self, config: TextConfig, shard: Shard):
     super().__init__()
     self.config = config
-    self.is_first_layer = is_first_layer
-    self.is_last_layer = is_last_layer
+    self.shard = shard
     self.vocab_size = config.vocab_size
     self.model_type = config.model_type
     self.num_hidden_layers = config.num_hidden_layers
@@ -381,10 +381,14 @@ class Llama(nn.Module):
     self.head_dim = config.head_dim
     assert self.vocab_size > 0
     self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-    self.layers = [
-      TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
-    ]
-    self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.shard.start_layer <= i <= self.shard.end_layer:
+        self.layers.append(TransformerBlock(config=config))
+      else:
+        self.layers.append(IdentityBlock())
+    if self.shard.is_last_layer():
+      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
   def __call__(
     self,
@@ -394,7 +398,7 @@ class Llama(nn.Module):
   ):
     # for passing merged input embeddings
     if inputs_embeds is None:
-      if self.is_first_layer:
+      if self.shard.is_first_layer():
         h = self.embed_tokens(inputs)
       else:
         h = inputs
@@ -413,20 +417,20 @@ class Llama(nn.Module):
     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, c)
 
-    if self.is_last_layer:
+    if self.shard.is_last_layer():
       h = self.norm(h)
     return h
 
 
 class LanguageModel(nn.Module):
-  def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
+  def __init__(self, config: TextConfig, shard: Shard):
     super().__init__()
     self.model_type = config.model_type
     if self.model_type != "llama":
       raise ValueError(
         f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
       )
-    self.is_last_layer = is_last_layer
-    self.model = Llama(config, is_first_layer, is_last_layer)
+    self.shard = shard
+    self.model = Llama(config, shard)
     self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
   def __call__(
@@ -436,7 +440,7 @@ class LanguageModel(nn.Module):
     inputs_embeds=None,
   ):
     out = self.model(inputs, cache, inputs_embeds)
-    if self.is_last_layer:
+    if self.shard.is_last_layer():
       out = self.lm_head(out)
     return out
 
@@ -485,8 +489,6 @@ class ModelArgs(LlaVAConfig):
     if not self.shard.is_first_layer():
       self.vision_config = None
 
-    self.text_config.num_hidden_layers = self.shard.get_layer_count()
-
 
 class LlavaMultiModalProjector(nn.Module):
   def __init__(self, config: LlaVAConfig):
@@ -516,7 +518,7 @@ class Model(nn.Module):
     self.multi_modal_projector = LlavaMultiModalProjector(config)
     self.vision_feature_layer = config.vision_feature_layer
     self.vision_feature_select_strategy = config.vision_feature_select_strategy
-    self.language_model = LanguageModel(config.text_config, config.shard.is_first_layer(), config.shard.is_last_layer())
+    self.language_model = LanguageModel(config.text_config, config.shard)
 
   def get_input_embeddings(
     self,