
Merge pull request #221 from exo-explore/qwen2.5

Qwen2.5
Alex Cheema, 7 months ago
commit 8ad19a5f53
4 files changed, 153 insertions(+), 3 deletions(-)
  1. exo/inference/mlx/models/qwen2.py (+127 -0)
  2. exo/models.py (+16 -0)
  3. test/test_tokenizers.py (+5 -3)
  4. tinychat/examples/tinychat/index.html (+5 -0)

exo/inference/mlx/models/qwen2.py (+127 -0)

@@ -0,0 +1,127 @@
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import create_attention_mask
+from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs
+
+from ...shard import Shard
+from .base import IdentityBlock
+
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    super().__post_init__()  # Ensure parent initializations are respected
+
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+class Qwen2Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    if self.args.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+        self.layers.append(TransformerBlock(args=args))
+      else:
+        self.layers.append(IdentityBlock())
+    if self.args.shard.is_last_layer():
+      self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+    else:
+      h = inputs
+
+    mask = None
+    if h.shape[1] > 1:
+      mask = create_attention_mask(h, cache)
+
+    if cache is None:
+      cache = [None] * len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+    self.model = Qwen2Model(args)
+    if self.args.shard.is_last_layer():
+      if not args.tie_word_embeddings:
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      if self.args.tie_word_embeddings:
+        out = self.model.embed_tokens.as_linear(out)
+      else:
+        out = self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
+        shard_state_dict[key] = value
+
+    if self.args.tie_word_embeddings:
+      shard_state_dict.pop("lm_head.weight", None)
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return self.args.hidden_size // self.args.num_attention_heads
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads
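For context on how this new sharded model class is used: the Shard carried on ModelArgs names a contiguous block of decoder layers that one node owns, and the is_first_layer()/is_last_layer() checks above gate the token embedding, final RMSNorm, and output head. Below is a minimal illustrative sketch (not part of the PR) of splitting the 28-layer Qwen2.5-7B checkpoint across two nodes; the import path assumes the same exo.inference.shard module referenced by the relative import above.

from exo.inference.shard import Shard

MODEL_ID = "mlx-community/Qwen2.5-7B-Instruct-4bit"  # registered in exo/models.py below

# Node A owns layers 0-13: embed_tokens is created (is_first_layer() is True),
# layers 14-27 are IdentityBlocks, and norm/lm_head are skipped entirely.
shard_a = Shard(model_id=MODEL_ID, start_layer=0, end_layer=13, n_layers=28)

# Node B owns layers 14-27: no embed_tokens, layers 0-13 are IdentityBlocks,
# and because is_last_layer() is True the final RMSNorm and output projection run here.
shard_b = Shard(model_id=MODEL_ID, start_layer=14, end_layer=27, n_layers=28)

# ModelArgs.__post_init__ also accepts the shard as a plain dict (e.g. parsed from
# JSON) and coerces it back into a Shard instance before the model is constructed.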

exo/models.py (+16 -0)

@@ -30,4 +30,20 @@ model_base_shards = {
   "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),},
   "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),},
   ### llava
   "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),},
   "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),},
+  ### qwen
+  "qwen-2.5-7b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
+  },
+  "qwen-2.5-math-7b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
+  },
+  "qwen-2.5-14b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48),
+  },
+  "qwen-2.5-72b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+  },
+  "qwen-2.5-math-72b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+  },
 }
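The new registry entries resolve like the existing ones: a short model name maps to a per-engine dict whose value is the default Shard for that checkpoint (start_layer and end_layer are left at 0 here, matching the other entries). A quick illustrative lookup, not part of the diff:

from exo.models import model_base_shards

shard = model_base_shards["qwen-2.5-7b"]["MLXDynamicShardInferenceEngine"]
print(shard.model_id)  # mlx-community/Qwen2.5-7B-Instruct-4bit
print(shard.n_layers)  # 28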

test/test_tokenizers.py (+5 -3)

@@ -1,3 +1,5 @@
+import os
+import re
 from transformers import AutoTokenizer, AutoProcessor
 from exo.models import model_base_shards
 
@@ -22,10 +24,10 @@ def test_tokenizer(name, tokenizer, verbose=False):
     strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
     assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
 
-ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "llava-hf/llava-1.5-7b-hf"]
-models = [shard.model_id for shards in model_base_shards.values() for shard in shards.values() if shard.model_id not in ignore]
+ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*"]
+ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
+models = [shard.model_id for shards in model_base_shards.values() for shard in shards.values() if not ignore_pattern.match(shard.model_id)]
 
-import os
 verbose = os.environ.get("VERBOSE", "0").lower() == "1"
 for m in models:
     # TODO: figure out why use_fast=False is giving inconsistent behaviour (no spaces decoding invididual tokens) for Mistral-Large-Instruct-2407-4bit
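The ignore list now accepts simple * wildcards: each entry has * expanded to .* and the resulting alternation is anchored at the start of the model id. A small self-contained check (illustrative, mirroring the test's construction) of how the new pattern behaves:

import re

ignore = ["mlx-community/Qwen*"]
ignore_pattern = re.compile(r"^(" + "|".join(m.replace("*", ".*") for m in ignore) + r")")

# All of the Qwen checkpoints added in this PR are skipped by the tokenizer test...
assert ignore_pattern.match("mlx-community/Qwen2.5-7B-Instruct-4bit")
assert ignore_pattern.match("mlx-community/Qwen2.5-Math-72B-Instruct-4bit")
# ...while ids that do not start with mlx-community/Qwen are not matched by this entry.
assert not ignore_pattern.match("llava-hf/llava-1.5-7b-hf")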

tinychat/examples/tinychat/index.html (+5 -0)

@@ -37,6 +37,11 @@
 <option value="mistral-large">Mistral Large</option>
 <option value="mistral-large">Mistral Large</option>
 <option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
 <option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
 <option value="llava-1.5-7b-hf">LLaVa 1.5 7B (Vision Model)</option>
 <option value="llava-1.5-7b-hf">LLaVa 1.5 7B (Vision Model)</option>
+<option value="qwen-2.5-7b">Qwen 2.5 7B</option>
+<option value="qwen-2.5-math-7b">Qwen 2.5 7B (Math)</option>
+<option value="qwen-2.5-14b">Qwen 2.5 14B</option>
+<option value="qwen-2.5-72b">Qwen 2.5 72B</option>
+<option value="qwen-2.5-math-72b">Qwen 2.5 72B (Math)</option>
 </select>
 </div>
 <div @popstate.window="