
Merge pull request #94 from mzbac/mlx_refactor

refactor(mlx): model sharding and add deepseek v2 support
Alex Cheema 11 months ago
parent commit 044d189ccc

+ 4 - 0
exo/api/chatgpt_api.py

@@ -38,6 +38,10 @@ shard_mappings = {
   "mistral-large": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
   },
+  ### deepseek v2
+  "deepseek-coder-v2-lite": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
+  },
 }
 
 
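For orientation, each shard_mappings entry pins the full layer count of a model; at runtime every node works with a Shard covering a contiguous slice of those layers. Below is a rough sketch of that idea, reusing the field order and helper names visible elsewhere in this diff (model_id, start_layer, end_layer, n_layers, is_first_layer()/is_last_layer()); the split_evenly helper is hypothetical and only illustrates how a layer range might be divided across nodes.

from dataclasses import dataclass

@dataclass
class Shard:
  model_id: str
  start_layer: int
  end_layer: int
  n_layers: int

  def is_first_layer(self) -> bool:
    return self.start_layer == 0

  def is_last_layer(self) -> bool:
    return self.end_layer == self.n_layers - 1

def split_evenly(model_id: str, n_layers: int, n_nodes: int) -> list:
  # Hypothetical helper: carve n_layers into n_nodes contiguous ranges.
  per_node = n_layers // n_nodes
  shards = []
  for i in range(n_nodes):
    start = i * per_node
    end = n_layers - 1 if i == n_nodes - 1 else start + per_node - 1
    shards.append(Shard(model_id, start, end, n_layers))
  return shards

# Splitting the 27-layer DeepSeek-Coder-V2-Lite entry above across two nodes:
for s in split_evenly("mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", 27, 2):
  print(s.start_layer, s.end_layer, s.is_first_layer(), s.is_last_layer())
# 0 12 True False
# 13 26 False True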

+ 9 - 0
exo/inference/mlx/models/base.py

@@ -0,0 +1,9 @@
+from typing import Optional
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.base import KVCache
+
+
+class IdentityBlock(nn.Module):
+  def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array:
+    return x

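The IdentityBlock added here is what makes per-node sharding cheap: every model keeps a full-length layers list so layer indices stay aligned with the checkpoint, but positions outside this node's [start_layer, end_layer] range become no-op blocks that pass the hidden state through untouched and hold no weights. A minimal pure-Python sketch of the pattern the model classes below follow (stand-in classes, made-up layer counts):

class IdentityBlock:
  # Pass-through stand-in for layers this node does not own.
  def __call__(self, x, mask=None, cache=None):
    return x

class FakeDecoderLayer:
  # Stand-in for a real DeepseekV2DecoderLayer / TransformerBlock.
  def __init__(self, idx):
    self.idx = idx

  def __call__(self, x, mask=None, cache=None):
    return x + [self.idx]

n_layers, start_layer, end_layer = 6, 2, 4  # this node owns layers 2..4
layers = [FakeDecoderLayer(i) if start_layer <= i <= end_layer else IdentityBlock()
          for i in range(n_layers)]

h = []  # stand-in for activations received from the previous node
for layer in layers:
  h = layer(h)
print(h)  # [2, 3, 4] -- only the owned layers transformed the activations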
+ 130 - 0
exo/inference/mlx/models/deepseek_v2.py

@@ -0,0 +1,130 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import KVCache
+from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
+from .base import IdentityBlock
+from ...shard import Shard
+
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class DeepseekV2Model(nn.Module):
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.args = config
+    self.num_hidden_layers = config.num_hidden_layers
+    self.vocab_size = config.vocab_size
+    if self.args.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+        self.layers.append(DeepseekV2DecoderLayer(config, i))
+      else:
+        self.layers.append(IdentityBlock())
+
+    if self.args.shard.is_last_layer():
+      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+  def __call__(
+    self,
+    x: mx.array,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(x)
+    else:
+      h = x
+
+    mask = None
+    T = h.shape[1]
+    if T > 1:
+      mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
+      mask = mask.astype(h.dtype)
+
+    if cache is None:
+      cache = [None] * len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.args = config
+    self.model_type = config.model_type
+    self.model = DeepseekV2Model(config)
+    if self.args.shard.is_last_layer():
+      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache: Optional[KVCache] = None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      return self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
+        shard_state_dict[key] = value
+
+    for l in range(self.args.num_hidden_layers):
+      prefix = f"model.layers.{l}"
+      for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
+        for k in ["weight", "scales", "biases"]:
+          if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
+            to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
+            shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return (
+      self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
+      self.args.v_head_dim,
+    )
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads

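sanitize() above does two jobs: it drops weights for layers outside this node's shard, and it repacks DeepSeek V2's per-expert MLP projections (experts.{e}.gate_proj/down_proj/up_proj) into the stacked switch_mlp tensors expected by the mlx_lm DeepseekV2DecoderLayer imported above. A toy illustration of the stacking step, assuming mlx is installed; the shapes are made up and only the key renaming and mx.stack pattern matter:

import mlx.core as mx

n_routed_experts, hidden, inter = 4, 8, 16
weights = {
  f"model.layers.0.mlp.experts.{e}.gate_proj.weight": mx.zeros((inter, hidden))
  for e in range(n_routed_experts)
}

# Pop the per-expert keys and stack them along a new leading expert axis.
stacked = mx.stack([
  weights.pop(f"model.layers.0.mlp.experts.{e}.gate_proj.weight")
  for e in range(n_routed_experts)
])
weights["model.layers.0.mlp.switch_mlp.gate_proj.weight"] = stacked
print(stacked.shape)  # (4, 16, 8)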
+ 124 - 0
exo/inference/mlx/models/llama.py

@@ -0,0 +1,124 @@
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import create_additive_causal_mask
+from mlx_lm.models.llama import TransformerBlock, ModelArgs
+
+from ...shard import Shard
+from .base import IdentityBlock
+
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    super().__post_init__()  # Ensure parent initializations are respected
+
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class LlamaModel(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    if self.args.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+        self.layers.append(TransformerBlock(args=args))
+      else:
+        self.layers.append(IdentityBlock())
+
+    if self.args.shard.is_last_layer():
+      self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+    else:
+      h = inputs
+
+    mask = None
+    if h.shape[1] > 1:
+      mask = create_additive_causal_mask(h.shape[1], cache[0].offset if cache is not None else 0)
+      mask = mask.astype(h.dtype)
+
+    if cache is None:
+      cache = [None] * len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, cache=c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+
+    self.model = LlamaModel(args)
+    if self.args.shard.is_last_layer():
+      if not args.tie_word_embeddings:
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      if self.args.tie_word_embeddings:
+        out = self.model.embed_tokens.as_linear(out)
+      else:
+        out = self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
+        shard_state_dict[key] = value
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return self.args.hidden_size // self.args.num_attention_heads
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads

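Compared with the monolithic sharded_llama.py removed below, the weight filtering now lives in each model's sanitize(): only layer weights inside [start_layer, end_layer] are kept, the first shard additionally keeps embed_tokens, and the last keeps norm and lm_head (with tied embeddings there is no lm_head tensor and the output reuses embed_tokens.as_linear). A toy run of that filtering logic for a hypothetical middle shard (key names follow the checkpoint layout referenced above, layer numbers are made up):

start_layer, end_layer = 2, 3
is_first, is_last = False, False  # a middle shard

weights = {
  "model.embed_tokens.weight": "...",
  "model.layers.0.self_attn.q_proj.weight": "...",
  "model.layers.2.self_attn.q_proj.weight": "...",
  "model.layers.3.mlp.gate_proj.weight": "...",
  "model.layers.5.mlp.gate_proj.weight": "...",
  "model.norm.weight": "...",
  "lm_head.weight": "...",
}

kept = {}
for key, value in weights.items():
  if key.startswith("model.layers."):
    if start_layer <= int(key.split(".")[2]) <= end_layer:
      kept[key] = value
  elif is_first and key.startswith("model.embed_tokens"):
    kept[key] = value
  elif is_last and (key.startswith("model.norm") or key.startswith("lm_head")):
    kept[key] = value

print(sorted(kept))  # only the layer-2 and layer-3 tensors survive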
+ 0 - 324
exo/inference/mlx/models/sharded_llama.py

@@ -1,324 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Dict, Optional, Union
-
-import mlx.core as mx
-import mlx.nn as nn
-
-from exo.inference.shard import Shard
-from mlx_lm.models.base import BaseModelArgs, KVCache, create_additive_causal_mask
-
-
-@dataclass
-class NormalModelArgs(BaseModelArgs):
-  model_type: str
-  hidden_size: int
-  num_hidden_layers: int
-  intermediate_size: int
-  num_attention_heads: int
-  rms_norm_eps: float
-  vocab_size: int
-  head_dim: Optional[int] = None
-  max_position_embeddings: Optional[int] = None
-  num_key_value_heads: Optional[int] = None
-  attention_bias: bool = False
-  mlp_bias: bool = False
-  rope_theta: float = 10000
-  rope_traditional: bool = False
-  rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-  tie_word_embeddings: bool = True
-
-  def __post_init__(self):
-    if self.num_key_value_heads is None:
-      self.num_key_value_heads = self.num_attention_heads
-
-    if self.rope_scaling:
-      if "factor" not in self.rope_scaling:
-        raise ValueError("rope_scaling must contain 'factor'")
-      rope_type = self.rope_scaling.get("type") or self.rope_scaling.get("rope_type")
-      if rope_type is None:
-        raise ValueError("rope_scaling must contain either 'type' or 'rope_type'")
-      if rope_type not in ["linear", "dynamic", "llama3"]:
-        raise ValueError("rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'")
-
-
-@dataclass
-class ModelArgs(NormalModelArgs):
-  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
-
-  def __post_init__(self):
-    super().__post_init__()  # Ensure parent initializations are respected
-
-    if isinstance(self.shard, Shard):
-      return
-    if not isinstance(self.shard, dict):
-      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
-
-    self.shard = Shard(**self.shard)
-
-
-class DynamicNTKScalingRoPE(nn.Module):
-  """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
-
-  def __init__(
-    self,
-    dims: int,
-    max_position_embeddings: int = 2048,
-    traditional: bool = False,
-    base: float = 10000,
-    scale: float = 1.0,
-    rope_type: str = "default",
-    rope_scaling: dict = None,
-  ):
-    super().__init__()
-    self.dims = dims
-    self.max_position_embeddings = max_position_embeddings
-    self.traditional = traditional
-    self.original_base = base
-    self.scale = scale
-    self.rope_type = rope_type
-    self.rope_scaling = rope_scaling
-    self.base = self.compute_base_freq()
-
-  def compute_base_freq(self):
-    if self.rope_type == "llama3":
-      return self.compute_llama3_base_freq()
-    return self.original_base
-
-  # source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
-  def compute_llama3_base_freq(self):
-    factor = self.rope_scaling["factor"]
-    low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
-    high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
-    old_context_len = self.rope_scaling.get(
-      "original_max_position_embeddings",
-      8192,
-    )
-
-    low_freq_wavelen = old_context_len / low_freq_factor
-    high_freq_wavelen = old_context_len / high_freq_factor
-
-    freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims)
-    wavelens = 2 * mx.pi * freqs
-    new_base_freqs = []
-
-    smooths = (wavelens - high_freq_wavelen) / (low_freq_wavelen - high_freq_wavelen)
-    new_base_freqs = freqs * (1 - smooths) * factor + smooths
-    new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
-    new_base_freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, new_base_freqs)
-    return new_base_freqs.mean().item()
-
-  def extra_repr(self):
-    return f"{self.dims}, traditional={self.traditional}, " f"max_position_embeddings={self.max_position_embeddings}, " f"scaling_factor={self.scale}, rope_type={self.rope_type}"
-
-  def __call__(self, x, offset: int = 0):
-    seq_len = x.shape[1] + offset
-    base = self.base
-    if self.max_position_embeddings and seq_len > self.max_position_embeddings:
-      base *= ((self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)) ** (self.dims / (self.dims - 2))
-
-    return mx.fast.rope(
-      x,
-      self.dims,
-      traditional=self.traditional,
-      base=base,
-      scale=self.scale,
-      offset=offset,
-    )
-
-
-def initialize_rope(args: ModelArgs):
-  head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
-
-  rope_scaling = args.rope_scaling
-  rope_type = "default"
-  rope_scale = 1.0
-
-  if rope_scaling is not None:
-    rope_type = rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
-    if rope_type == "linear":
-      rope_scale = 1 / rope_scaling["factor"]
-    elif rope_type == "llama3":
-      rope_scale = 1.0  # The scaling is handled internally for llama3
-
-  return DynamicNTKScalingRoPE(
-    dims=head_dim,
-    max_position_embeddings=args.max_position_embeddings,
-    traditional=args.rope_traditional,
-    base=args.rope_theta,
-    scale=rope_scale,
-    rope_type=rope_type,
-    rope_scaling=rope_scaling,
-  )
-
-
-class Attention(nn.Module):
-  def __init__(self, args: ModelArgs):
-    super().__init__()
-
-    dim = args.hidden_size
-    self.n_heads = n_heads = args.num_attention_heads
-    self.n_kv_heads = n_kv_heads = args.num_key_value_heads
-
-    self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
-
-    self.scale = head_dim**-0.5
-    if hasattr(args, "attention_bias"):
-      attention_bias = args.attention_bias
-    else:
-      attention_bias = False
-
-    self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
-    self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
-    self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
-    self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
-
-    self.rope = initialize_rope(args)
-
-  def __call__(
-    self,
-    x: mx.array,
-    mask: Optional[mx.array] = None,
-    cache: Optional[KVCache] = None,
-  ) -> mx.array:
-    B, L, _D = x.shape
-
-    queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-    # Prepare the queries, keys and values for the attention computation
-    queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
-    keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-    values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-
-    if cache is not None:
-      queries = self.rope(queries, offset=cache.offset)
-      keys = self.rope(keys, offset=cache.offset)
-      keys, values = cache.update_and_fetch(keys, values)
-    else:
-      queries = self.rope(queries)
-      keys = self.rope(keys)
-
-    output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask)
-    output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-    return self.o_proj(output)
-
-
-class MLP(nn.Module):
-  def __init__(self, args: ModelArgs):
-    super().__init__()
-
-    dim = args.hidden_size
-    hidden_dim = args.intermediate_size
-    if hasattr(args, "mlp_bias"):
-      mlp_bias = args.mlp_bias
-    else:
-      mlp_bias = False
-
-    self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
-    self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
-    self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
-
-  def __call__(self, x) -> mx.array:
-    return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
-
-
-class TransformerBlock(nn.Module):
-  def __init__(self, args: ModelArgs):
-    super().__init__()
-    self.num_attention_heads = args.num_attention_heads
-    self.hidden_size = args.hidden_size
-    self.self_attn = Attention(args)
-    self.mlp = MLP(args)
-    self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-    self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-    self.args = args
-
-  def __call__(
-    self,
-    x: mx.array,
-    mask: Optional[mx.array] = None,
-    cache: Optional[KVCache] = None,
-  ) -> mx.array:
-    r = self.self_attn(self.input_layernorm(x), mask, cache)
-    h = x + r
-    r = self.mlp(self.post_attention_layernorm(h))
-    out = h + r
-    return out
-
-
-class LlamaModel(nn.Module):
-  def __init__(self, args: ModelArgs):
-    super().__init__()
-    self.args = args
-    self.vocab_size = args.vocab_size
-    self.num_hidden_layers = args.num_hidden_layers
-    assert self.vocab_size > 0
-    self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
-    self.layers = [TransformerBlock(args=args) for _ in range(args.shard.end_layer - args.shard.start_layer + 1)]
-    self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-
-  def __call__(
-    self,
-    inputs: mx.array,
-    cache=None,
-  ):
-    if self.args.shard.is_first_layer():
-      h = self.embed_tokens(inputs)
-    else:
-      h = inputs
-
-    mask = None
-    if h.shape[1] > 1:
-      mask = create_additive_causal_mask(h.shape[1], cache[0].offset if cache is not None else 0)
-      mask = mask.astype(h.dtype)
-
-    if cache is None:
-      cache = [None] * len(self.layers)
-
-    for layer, c in zip(self.layers, cache):
-      h = layer(h, mask, cache=c)
-
-    if self.args.shard.is_last_layer():
-      return self.norm(h)
-    else:
-      return h
-
-
-class Model(nn.Module):
-  def __init__(self, args: ModelArgs):
-    super().__init__()
-    self.args = args
-    self.model_type = args.model_type
-    self.model = LlamaModel(args)
-    if not args.tie_word_embeddings:
-      self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
-
-  def __call__(
-    self,
-    inputs: mx.array,
-    cache=None,
-  ):
-    out = self.model(inputs, cache)
-
-    if self.args.shard.is_last_layer():
-      if self.args.tie_word_embeddings:
-        out = self.model.embed_tokens.as_linear(out)
-      else:
-        out = self.lm_head(out)
-
-    return out
-
-  def sanitize(self, weights):
-    # Remove unused precomputed rotary freqs
-    return {k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k}
-
-  @property
-  def layers(self):
-    return self.model.layers
-
-  @property
-  def head_dim(self):
-    return self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
-
-  @property
-  def n_kv_heads(self):
-    return self.args.num_key_value_heads

+ 31 - 47
exo/inference/mlx/sharded_utils.py

@@ -27,8 +27,8 @@ class ModelNotFoundError(Exception):
 
 
 MODEL_REMAPPING = {
-  "sharded_mistral": "sharded_llama",  # mistral is compatible with llama
-  "sharded_phi-msft": "sharded_phixtral",
+  "mistral": "llama",  # mistral is compatible with llama
+  "phi-msft": "phixtral",
 }
 
 
@@ -37,10 +37,10 @@ def _get_classes(config: dict):
   Retrieve the model and model args classes based on the configuration.
 
   Args:
-  config (dict): The model configuration.
+   config (dict): The model configuration.
 
   Returns:
-  A tuple containing the Model class and the ModelArgs class.
+   A tuple containing the Model class and the ModelArgs class.
   """
   model_type = config["model_type"]
   model_type = MODEL_REMAPPING.get(model_type, model_type)
@@ -74,26 +74,24 @@ def load_model_shard(
   Load and initialize the model from a given path.
 
   Args:
-  model_path (Path): The path to load the model from.
-  lazy (bool): If False eval the model parameters to make sure they are
-  loaded in memory before returning, otherwise they will be loaded
-  when needed. Default: ``False``
-  model_config(dict, optional): Configuration parameters for the model.
-  Defaults to an empty dictionary.
+   model_path (Path): The path to load the model from.
+   lazy (bool): If False eval the model parameters to make sure they are
+    loaded in memory before returning, otherwise they will be loaded
+    when needed. Default: ``False``
+   model_config(dict, optional): Configuration parameters for the model.
+    Defaults to an empty dictionary.
 
   Returns:
-  nn.Module: The loaded and initialized model.
+   nn.Module: The loaded and initialized model.
 
   Raises:
-  FileNotFoundError: If the weight files (.safetensors) are not found.
-  ValueError: If the model class or args class are not found or cannot be instantiated.
+   FileNotFoundError: If the weight files (.safetensors) are not found.
+   ValueError: If the model class or args class are not found or cannot be instantiated.
   """
-
   config = load_config(model_path)
   config.update(model_config)
 
   # TODO hack
-  config["model_type"] = f"sharded_{config['model_type']}"
   config["shard"] = {
     "model_id": model_path.name,
     "start_layer": shard.start_layer,
@@ -112,11 +110,8 @@ def load_model_shard(
     raise FileNotFoundError(f"No safetensors found in {model_path}")
 
   weights = {}
-  all_weights_keys = set()
   for wf in weight_files:
-    weights_dict = mx.load(wf)
-    all_weights_keys.update(weights_dict.keys())
-    weights.update({k: v for k, v in weights_dict.items() if not k.startswith("model.layers.") or shard.start_layer <= int(k.split(".")[2]) <= shard.end_layer})
+    weights.update(mx.load(wf))
 
   model_class, model_args_class = _get_classes(config=config)
 
@@ -133,18 +128,7 @@ def load_model_shard(
       class_predicate=None,
     )
 
-  filtered_weights = {}
-  for k, v in weights.items():
-    if k.startswith("model.layers."):
-      layer_num = int(k.split(".")[2])
-      if shard.start_layer <= layer_num <= shard.end_layer:
-        new_key = f"model.layers.{layer_num - shard.start_layer}." + ".".join(k.split(".")[3:])
-        filtered_weights[new_key] = v
-    else:
-      filtered_weights[k] = v
-  weights = filtered_weights
-
-  model.load_weights(list(weights.items()), strict=False)
+  model.load_weights(list(weights.items()))
 
   if not lazy:
     mx.eval(model.parameters())
@@ -164,11 +148,11 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -
   it is downloaded from the Hugging Face Hub.
 
   Args:
-  path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
-  revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
+   path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
+   revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
 
   Returns:
-  Path: The path to the model.
+   Path: The path to the model.
   """
   model_path = Path(path_or_hf_repo)
   if not model_path.exists():
@@ -210,22 +194,22 @@ async def load_shard(
   Load the model and tokenizer from a given path or a huggingface repository.
 
   Args:
-  path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
-  tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
-  Defaults to an empty dictionary.
-  model_config(dict, optional): Configuration parameters specifically for the model.
-  Defaults to an empty dictionary.
-  adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
-  to the model. Default: ``None``.
-  lazy (bool): If False eval the model parameters to make sure they are
-  loaded in memory before returning, otherwise they will be loaded
-  when needed. Default: ``False``
+   path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
+   tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
+    Defaults to an empty dictionary.
+   model_config(dict, optional): Configuration parameters specifically for the model.
+    Defaults to an empty dictionary.
+   adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
+    to the model. Default: ``None``.
+   lazy (bool): If False eval the model parameters to make sure they are
+    loaded in memory before returning, otherwise they will be loaded
+    when needed. Default: ``False``
   Returns:
-  Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
+   Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
 
   Raises:
-  FileNotFoundError: If config file or safetensors are not found.
-  ValueError: If model class or args class are not found.
+   FileNotFoundError: If config file or safetensors are not found.
+   ValueError: If model class or args class are not found.
   """
   model_path = await get_model_path(path_or_hf_repo)
 

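Taken together, the loader changes above drop the "sharded_" model_type hack: _get_classes resolves config["model_type"] directly (via MODEL_REMAPPING), all safetensors are loaded, and each model's sanitize() is now responsible for discarding out-of-shard weights before the now-strict load_weights call. A condensed sketch of the resolution step; the importlib path is an assumption about how _get_classes locates the model modules in exo/inference/mlx/models/, since the diff only shows the remapping table and its call site:

import importlib

MODEL_REMAPPING = {
  "mistral": "llama",     # mistral checkpoints reuse the llama implementation
  "phi-msft": "phixtral",
}

def resolve_model_module(config: dict):
  model_type = config["model_type"]
  model_type = MODEL_REMAPPING.get(model_type, model_type)
  # Assumed import location, mirroring the exo/inference/mlx/models/ layout in this PR.
  return importlib.import_module(f"exo.inference.mlx.models.{model_type}")

# e.g. a Mistral config resolves to the llama module defined in this PR:
# module = resolve_model_module({"model_type": "mistral"})
# Model, ModelArgs = module.Model, module.ModelArgs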
+ 1 - 0
tinychat/examples/tinychat/index.html

@@ -64,6 +64,7 @@
         <option value="llama-3-70b">Llama 3 70B</option>
         <option value="mistral-nemo">Mistral Nemo</option>
         <option value="mistral-large">Mistral Large</option>
+        <option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
       </select>
     </div>
     <div class="home centered" x-show="home === 0" x-transition x-effect="