瀏覽代碼

Merge pull request #596 from exo-explore/phi4

add phi 3.5, phi 4
Alex Cheema 3 月之前
父節點
當前提交
2d631ea53d
共有 4 個文件被更改,包括 127 次插入4 次删除
  1. 117 0
      exo/inference/mlx/models/phi3.py
  2. 4 3
      exo/inference/mlx/models/qwen2.py
  3. 5 0
      exo/models.py
  4. 1 1
      test/test_tokenizers.py

+ 117 - 0
exo/inference/mlx/models/phi3.py

@@ -0,0 +1,117 @@
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import create_attention_mask
+from mlx_lm.models.phi3 import TransformerBlock, ModelArgs
+
+from ...shard import Shard
+from .base import IdentityBlock
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    super().__post_init__()
+
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+class Phi3Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    
+    if self.args.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+        self.layers.append(TransformerBlock(args=args))
+      else:
+        self.layers.append(IdentityBlock())
+        
+    if self.args.shard.is_last_layer():
+      self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+    else:
+      h = inputs
+
+    mask = None
+    if h.shape[1] > 1:
+      mask = create_attention_mask(h, cache)
+
+    if cache is None:
+      cache = [None] * len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+class Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+    self.model = Phi3Model(args)
+    if self.args.shard.is_last_layer():
+      self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      out = self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if "self_attn.rope.inv_freq" in key:
+        continue
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('lm_head') or key.startswith('model.norm')):
+        shard_state_dict[key] = value
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return self.args.hidden_size // self.args.num_attention_heads
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads

+ 4 - 3
exo/inference/mlx/models/qwen2.py

@@ -9,13 +9,12 @@ from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs
 from ...shard import Shard
 from .base import IdentityBlock
 
-
 @dataclass
 class ModelArgs(ModelArgs):
   shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
 
   def __post_init__(self):
-    super().__post_init__()  # Ensure parent initializations are respected
+    super().__post_init__()
 
     if isinstance(self.shard, Shard):
       return
@@ -24,7 +23,6 @@ class ModelArgs(ModelArgs):
 
     self.shard = Shard(**self.shard)
 
-
 class Qwen2Model(nn.Module):
   def __init__(self, args: ModelArgs):
     super().__init__()
@@ -32,14 +30,17 @@ class Qwen2Model(nn.Module):
     self.vocab_size = args.vocab_size
     self.num_hidden_layers = args.num_hidden_layers
     assert self.vocab_size > 0
+
     if self.args.shard.is_first_layer():
       self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+
     self.layers = []
     for i in range(self.num_hidden_layers):
       if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
         self.layers.append(TransformerBlock(args=args))
       else:
         self.layers.append(IdentityBlock())
+
     if self.args.shard.is_last_layer():
       self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 

+ 5 - 0
exo/models.py

@@ -111,6 +111,9 @@ model_cards = {
   # gemma
   "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
   "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
+  # phi
+  "phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, },
+  "phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
   # dummy
   "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
 }
@@ -149,6 +152,8 @@ pretty_name = {
   "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
   "qwen-2.5-72b": "Qwen 2.5 72B",
   "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
+  "phi-3.5-mini": "Phi-3.5 Mini",
+  "phi-4": "Phi-4",
   "llama-3-8b": "Llama 3 8B",
   "llama-3-70b": "Llama 3 70B",
 }

+ 1 - 1
test/test_tokenizers.py

@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
     strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
     assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
 
-ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"]
+ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-Mini-Instruct-4bit", "mlx-community/Phi-4-4bit"]
 ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
 models = []
 for model_id in model_cards: