
add deepseek v3, r1 and all the distills

Alex Cheema committed 6 months ago
commit d8ffa59dba
3 changed files with 203 additions and 2 deletions
  1. exo/inference/mlx/models/deepseek_v3.py (+135 -0)
  2. exo/models.py (+66 -0)
  3. setup.py (+2 -2)

exo/inference/mlx/models/deepseek_v3.py (+135 -0)

@@ -0,0 +1,135 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.cache import KVCache
+from mlx_lm.models.deepseek_v3 import (
+  ModelArgs as V3ModelArgs,
+  DeepseekV3DecoderLayer,
+)
+from .base import IdentityBlock
+from exo.inference.shard import Shard
+
+
+@dataclass
+class ModelArgs(V3ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    super().__post_init__()
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class DeepseekV3Model(nn.Module):
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.args = config
+    self.num_hidden_layers = config.num_hidden_layers
+    self.vocab_size = config.vocab_size
+    if self.args.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+        self.layers.append(DeepseekV3DecoderLayer(config, i))
+      else:
+        self.layers.append(IdentityBlock())
+
+    if self.args.shard.is_last_layer():
+      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+  def __call__(
+    self,
+    x: mx.array,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(x)
+    else:
+      h = x
+
+    mask = None
+    T = h.shape[1]
+    if T > 1:
+      mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
+      mask = mask.astype(h.dtype)
+
+    if cache is None:
+      cache = [None]*len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.args = config
+    self.model_type = config.model_type
+    self.model = DeepseekV3Model(config)
+    if self.args.shard.is_last_layer():
+      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache: Optional[KVCache] = None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      return self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
+        shard_state_dict[key] = value
+
+    for l in range(self.args.num_hidden_layers):
+      prefix = f"model.layers.{l}"
+      for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
+        for k in ["weight", "scales", "biases"]:
+          expert_key = f"{prefix}.mlp.experts.0.{m}.{k}"
+          if expert_key in shard_state_dict:
+            to_join = [
+              shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
+              for e in range(self.args.n_routed_experts)
+            ]
+            shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return (
+      self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
+      self.args.v_head_dim,
+    )
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads
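
The decoder stack above instantiates a real DeepseekV3DecoderLayer only for indices inside the shard's [start_layer, end_layer] range and an IdentityBlock everywhere else, so each node materializes just its slice of the network; embed_tokens, norm, and lm_head exist only on the first and last shards. Below is a minimal standalone sketch of the Shard contract this relies on; the fields and predicates mirror exo/inference/shard.py, but this copy is illustrative rather than authoritative:

from dataclasses import dataclass

@dataclass(frozen=True)
class Shard:
  model_id: str
  start_layer: int  # first layer this node owns (inclusive)
  end_layer: int    # last layer this node owns (inclusive)
  n_layers: int     # total layer count for the model

  def is_first_layer(self) -> bool:
    # Only the shard holding layer 0 runs embed_tokens.
    return self.start_layer == 0

  def is_last_layer(self) -> bool:
    # Only the shard holding the final layer runs norm and lm_head.
    return self.end_layer == self.n_layers - 1

# Second of two nodes splitting DeepSeek V3's 61 layers:
shard = Shard("deepseek-v3", start_layer=31, end_layer=60, n_layers=61)
kinds = ["decoder" if shard.start_layer <= i <= shard.end_layer else "identity"
         for i in range(shard.n_layers)]
assert not shard.is_first_layer() and shard.is_last_layer()
assert kinds.count("decoder") == 30

The sanitize method applies the same boundary rules when filtering checkpoint weights, and additionally restacks the per-expert gate/down/up projections (experts.0 through the last routed expert) into the single stacked switch_mlp tensors that mlx-lm's fused MoE implementation loads.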

exo/models.py (+66 -0)

@@ -88,6 +88,38 @@ model_cards = {
   ### deepseek
   "deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
   "deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
+  "deepseek-v3": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V3-4bit", }, },
+  "deepseek-r1": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-4bit", }, },
+  ### deepseek distills
+  "deepseek-r1-distill-qwen-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/deepseek-r1-distill-qwen-1.5b", }, },
+  "deepseek-r1-distill-qwen-1.5b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-3bit", }, },
+  "deepseek-r1-distill-qwen-1.5b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-6bit", }, },
+  "deepseek-r1-distill-qwen-1.5b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit", }, },
+  "deepseek-r1-distill-qwen-1.5b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-bf16", }, },
+  "deepseek-r1-distill-qwen-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit", }, },
+  "deepseek-r1-distill-qwen-7b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-3bit", }, },
+  "deepseek-r1-distill-qwen-7b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-6bit", }, },
+  "deepseek-r1-distill-qwen-7b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-8bit", }, },
+  "deepseek-r1-distill-qwen-7b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-bf16", }, },
+  "deepseek-r1-distill-qwen-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-4bit", }, },
+  "deepseek-r1-distill-qwen-14b-3bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-3bit", }, },
+  "deepseek-r1-distill-qwen-14b-6bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-6bit", }, },
+  "deepseek-r1-distill-qwen-14b-8bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-8bit", }, },
+  "deepseek-r1-distill-qwen-14b-bf16": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-bf16", }, },
+  "deepseek-r1-distill-qwen-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-4bit", }, },
+  "deepseek-r1-distill-qwen-32b-3bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-3bit", }, },
+  "deepseek-r1-distill-qwen-32b-6bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-6bit", }, },
+  "deepseek-r1-distill-qwen-32b-8bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-MLX-8Bit", }, },
+  "deepseek-r1-distill-qwen-32b-bf16": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-bf16", }, },
+  "deepseek-r1-distill-llama-8b": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-4bit", }, },
+  "deepseek-r1-distill-llama-8b-3bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-3bit", }, },
+  "deepseek-r1-distill-llama-8b-6bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-6bit", }, },
+  "deepseek-r1-distill-llama-8b-8bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-8bit", }, },
+  "deepseek-r1-distill-llama-8b-bf16": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-bf16", }, },
+  "deepseek-r1-distill-llama-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-4bit", }, },
+  "deepseek-r1-distill-llama-70b-3bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-3bit", }, },
+  "deepseek-r1-distill-llama-70b-6bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-6bit", }, },
+  "deepseek-r1-distill-llama-70b-8bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-8bit", }, },
   ### llava
   "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
   ### qwen
@@ -140,6 +172,8 @@ pretty_name = {
   "mistral-large": "Mistral Large",
   "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
   "deepseek-coder-v2.5": "Deepseek Coder V2.5",
+  "deepseek-v3": "Deepseek V3",
+  "deepseek-r1": "Deepseek R1",
   "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
   "qwen-2.5-1.5b": "Qwen 2.5 1.5B",
   "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
@@ -159,6 +193,38 @@ pretty_name = {
   "llama-3-8b": "Llama 3 8B",
   "llama-3-70b": "Llama 3 70B",
   "stable-diffusion-2-1-base": "Stable Diffusion 2.1",
+  "deepseek-r1-distill-qwen-1.5b": "DeepSeek R1 Distill Qwen 1.5B",
+  "deepseek-r1-distill-qwen-1.5b-3bit": "DeepSeek R1 Distill Qwen 1.5B (3-bit)",
+  "deepseek-r1-distill-qwen-1.5b-6bit": "DeepSeek R1 Distill Qwen 1.5B (6-bit)",
+  "deepseek-r1-distill-qwen-1.5b-8bit": "DeepSeek R1 Distill Qwen 1.5B (8-bit)",
+  "deepseek-r1-distill-qwen-1.5b-bf16": "DeepSeek R1 Distill Qwen 1.5B (BF16)",
+  "deepseek-r1-distill-qwen-7b": "DeepSeek R1 Distill Qwen 7B",
+  "deepseek-r1-distill-qwen-7b-3bit": "DeepSeek R1 Distill Qwen 7B (3-bit)",
+  "deepseek-r1-distill-qwen-7b-6bit": "DeepSeek R1 Distill Qwen 7B (6-bit)",
+  "deepseek-r1-distill-qwen-7b-8bit": "DeepSeek R1 Distill Qwen 7B (8-bit)",
+  "deepseek-r1-distill-qwen-7b-bf16": "DeepSeek R1 Distill Qwen 7B (BF16)",
+  "deepseek-r1-distill-qwen-14b": "DeepSeek R1 Distill Qwen 14B",
+  "deepseek-r1-distill-qwen-14b-3bit": "DeepSeek R1 Distill Qwen 14B (3-bit)",
+  "deepseek-r1-distill-qwen-14b-6bit": "DeepSeek R1 Distill Qwen 14B (6-bit)",
+  "deepseek-r1-distill-qwen-14b-8bit": "DeepSeek R1 Distill Qwen 14B (8-bit)",
+  "deepseek-r1-distill-qwen-14b-bf16": "DeepSeek R1 Distill Qwen 14B (BF16)",
+  "deepseek-r1-distill-qwen-32b": "DeepSeek R1 Distill Qwen 32B",
+  "deepseek-r1-distill-qwen-32b-3bit": "DeepSeek R1 Distill Qwen 32B (3-bit)",
+  "deepseek-r1-distill-qwen-32b-8bit": "DeepSeek R1 Distill Qwen 32B (8-bit)",
+  "deepseek-r1-distill-qwen-32b-bf16": "DeepSeek R1 Distill Qwen 32B (BF16)",
+  "deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
+  "deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
+  "deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
+  "deepseek-r1-distill-llama-8b": "DeepSeek R1 Distill Llama 8B",
+  "deepseek-r1-distill-llama-8b-3bit": "DeepSeek R1 Distill Llama 8B (3-bit)",
+  "deepseek-r1-distill-llama-8b-6bit": "DeepSeek R1 Distill Llama 8B (6-bit)",
+  "deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
+  "deepseek-r1-distill-llama-8b-bf16": "DeepSeek R1 Distill Llama 8B (BF16)",
+  "deepseek-r1-distill-llama-70b": "DeepSeek R1 Distill Llama 70B",
+  "deepseek-r1-distill-llama-70b-3bit": "DeepSeek R1 Distill Llama 70B (3-bit)",
+  "deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
+  "deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
+  "deepseek-r1-distill-qwen-32b-6bit": "DeepSeek R1 Distill Qwen 32B (6-bit)",
 }
 
 def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
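
Every new model_cards entry above resolves through get_repo, whose signature appears in the diff context. Here is a hedged sketch of that lookup with a single entry inlined so it runs standalone; the dictionary-walk body is an assumption from the signature, and exo's actual implementation may differ in detail:

from typing import Optional

model_cards = {
  "deepseek-r1-distill-llama-8b": {
    "layers": 32,
    "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-4bit"},
  },
}

def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
  # Walk model_cards -> "repo" -> engine class name; any missing key yields None.
  return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname)

assert get_repo("deepseek-r1-distill-llama-8b",
                "MLXDynamicShardInferenceEngine") == "mlx-community/DeepSeek-R1-Distill-Llama-8B-4bit"
assert get_repo("deepseek-r1-distill-llama-8b", "TinygradDynamicShardInferenceEngine") is None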

setup.py (+2 -2)

@@ -35,8 +35,8 @@ install_requires = [
 extras_require = {
   "formatting": ["yapf==0.40.2",],
   "apple_silicon": [
-    "mlx==0.21.1",
-    "mlx-lm==0.20.4",
+    "mlx==0.22.0",
+    "mlx-lm==0.21.1",
   ],
   "windows": ["pywin32==308",],
   "nvidia-gpu": ["nvidia-ml-py==12.560.30",],