
add Gemma2 9b and Gemma2 27b

Alex Cheema, 5 months ago
parent commit fcaebd3b50
3 changed files with 125 additions and 2 deletions
  1. exo/inference/mlx/models/gemma2.py  +118 -0
  2. exo/models.py  +3 -0
  3. exo/tinychat/index.html  +4 -2

exo/inference/mlx/models/gemma2.py  +118 -0

@@ -0,0 +1,118 @@
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import create_attention_mask
+from mlx_lm.models.gemma2 import TransformerBlock, ModelArgs, RMSNorm
+
+from ...shard import Shard
+from .base import IdentityBlock
+
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class GemmaModel(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    if args.shard.is_first_layer() or args.shard.is_last_layer():
+      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if args.shard.start_layer <= i <= args.shard.end_layer:
+        self.layers.append(TransformerBlock(args=args))
+      else:
+        self.layers.append(IdentityBlock())
+    if args.shard.is_last_layer():
+      self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+      h = h * (self.args.hidden_size**0.5)
+    else:
+      h = inputs
+
+    mask = None
+    if h.ndim > 1 and h.shape[1] > 1:
+      mask = create_attention_mask(h, cache)
+
+    if cache is None:
+      cache = [None]*len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, cache=c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+    self.model = GemmaModel(args)
+    if args.shard.is_last_layer():
+      self.final_logit_softcapping = args.final_logit_softcapping
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      out = self.model.embed_tokens.as_linear(out)
+      out = mx.tanh(out / self.final_logit_softcapping)
+      out = out * self.final_logit_softcapping
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif (self.args.shard.is_first_layer() or self.args.shard.is_last_layer()) and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
+        shard_state_dict[key] = value
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return self.args.head_dim
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads
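
A minimal sketch (not part of the commit) of how the shard-aware layer selection in GemmaModel.__init__ plays out. The Shard class below is a stand-in with assumed is_first_layer/is_last_layer semantics; the real one is imported from exo via the relative import from ...shard shown above.

from dataclasses import dataclass

# Stand-in for exo's Shard, assuming is_first_layer/is_last_layer simply
# compare the shard's layer range against the ends of the model.
@dataclass
class Shard:
  model_id: str
  start_layer: int
  end_layer: int
  n_layers: int

  def is_first_layer(self) -> bool:
    return self.start_layer == 0

  def is_last_layer(self) -> bool:
    return self.end_layer == self.n_layers - 1

# Second half of Gemma2 9B (42 layers) on one node of a two-node split.
shard = Shard("mlx-community/gemma-2-9b-it-4bit", start_layer=21, end_layer=41, n_layers=42)

# Same selection rule as GemmaModel.__init__: layers inside [start_layer, end_layer]
# become real TransformerBlocks, everything else an IdentityBlock pass-through.
kinds = [
  "TransformerBlock" if shard.start_layer <= i <= shard.end_layer else "IdentityBlock"
  for i in range(shard.n_layers)
]
print(kinds.count("TransformerBlock"), "real layers;", kinds.count("IdentityBlock"), "identity placeholders")
# -> 21 real layers; 21 identity placeholders

The same range drives sanitize above, which keeps only the weights for layers inside the shard, plus model.embed_tokens on the first/last shard and model.norm on the last.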

exo/models.py  +3 -0

@@ -45,6 +45,9 @@ model_base_shards = {
   ### nemotron
   "nemotron-70b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", start_layer=0, end_layer=0, n_layers=80),},
   "nemotron-70b-bf16": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", start_layer=0, end_layer=0, n_layers=80),},
+  # gemma
+  "gemma2-9b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/gemma-2-9b-it-4bit", start_layer=0, end_layer=0, n_layers=42),},
+  "gemma2-27b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/gemma-2-27b-it-4bit", start_layer=0, end_layer=0, n_layers=46),},
   # dummy
   "dummy": {"DummyInferenceEngine": Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8),},
 }
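
The two new entries register base shards covering layer 0 of the full model: 42 layers for the 9B, 46 for the 27B. A quick illustrative lookup, reusing the stand-in Shard from the sketch above (the dict name here is only for illustration; the real entries live in model_base_shards in exo/models.py):

# Mirrors the model name -> engine name -> base Shard structure of model_base_shards.
gemma2_entries = {
  "gemma2-9b": {"MLXDynamicShardInferenceEngine": Shard("mlx-community/gemma-2-9b-it-4bit", start_layer=0, end_layer=0, n_layers=42)},
  "gemma2-27b": {"MLXDynamicShardInferenceEngine": Shard("mlx-community/gemma-2-27b-it-4bit", start_layer=0, end_layer=0, n_layers=46)},
}

base = gemma2_entries["gemma2-9b"]["MLXDynamicShardInferenceEngine"]
print(base.model_id, base.n_layers)  # mlx-community/gemma-2-9b-it-4bit 42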

exo/tinychat/index.html  +4 -2

@@ -37,8 +37,8 @@
 <option value="llama-3.1-70b-bf16">Llama 3.1 70B (BF16)</option>
 <option value="llama-3.1-405b">Llama 3.1 405B</option>
 <option value="llama-3.1-405b-8bit">Llama 3.1 405B (8-bit)</option>
-<option value="llama-3-8b">Llama 3 8B</option>
-<option value="llama-3-70b">Llama 3 70B</option>
+<option value="gemma2-9b">Gemma2 9B</option>
+<option value="gemma2-27b">Gemma2 27B</option>
 <option value="nemotron-70b">Nemotron 70B</option>
 <option value="nemotron-70b-bf16">Nemotron 70B (BF16)</option>
 <option value="mistral-nemo">Mistral Nemo</option>
@@ -53,6 +53,8 @@
 <option value="qwen-2.5-14b">Qwen 2.5 14B</option>
 <option value="qwen-2.5-72b">Qwen 2.5 72B</option>
 <option value="qwen-2.5-math-72b">Qwen 2.5 72B (Math)</option>
+<option value="llama-3-8b">Llama 3 8B</option>
+<option value="llama-3-70b">Llama 3 70B</option>
 </select>
 </div>
 <div @popstate.window="