deepseek_v2.py

from dataclasses import dataclass, field
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import KVCache
from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer

from .base import IdentityBlock
from exo.inference.shard import Shard
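

# ModelArgs extends the upstream mlx_lm DeepSeek-V2 config with a `shard` field that
# records which contiguous slice of decoder layers this node owns. The shard may arrive
# as a plain dict (e.g. after deserialization), so __post_init__ normalizes it into a
# Shard instance and rejects anything else.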
@dataclass
class ModelArgs(ModelArgs):
  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))

  def __post_init__(self):
    if isinstance(self.shard, Shard):
      return
    if not isinstance(self.shard, dict):
      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")

    self.shard = Shard(**self.shard)
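

# Sharded DeepSeek-V2 backbone: only layer indices inside [shard.start_layer, shard.end_layer]
# are real DeepseekV2DecoderLayers; every other position is an IdentityBlock that passes the
# hidden states through unchanged. The token embedding exists only on the first shard and the
# final RMSNorm only on the last shard.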
class DeepseekV2Model(nn.Module):
  def __init__(self, config: ModelArgs):
    super().__init__()
    self.args = config
    self.num_hidden_layers = config.num_hidden_layers
    self.vocab_size = config.vocab_size
    if self.args.shard.is_first_layer():
      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

    self.layers = []
    for i in range(self.num_hidden_layers):
      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
        self.layers.append(DeepseekV2DecoderLayer(config, i))
      else:
        self.layers.append(IdentityBlock())

    if self.args.shard.is_last_layer():
      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
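
  # Forward pass: the first shard embeds token ids, while downstream shards receive the
  # previous shard's hidden states directly. A causal mask is only built when more than
  # one position is processed at once (prompt processing vs. single-token decode).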
  def __call__(
    self,
    x: mx.array,
    cache: Optional[KVCache] = None,
  ) -> mx.array:
    if self.args.shard.is_first_layer():
      h = self.embed_tokens(x)
    else:
      h = x

    mask = None
    T = h.shape[1]
    if T > 1:
      mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
      mask = mask.astype(h.dtype)

    if cache is None:
      cache = [None] * len(self.layers)

    for layer, c in zip(self.layers, cache):
      h = layer(h, mask, c)

    if self.args.shard.is_last_layer():
      h = self.norm(h)
    return h
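

# Top-level wrapper: adds the language-model head on the last shard and exposes the
# cache metadata (layers, head_dim, n_kv_heads) that the MLX KV-cache machinery expects.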
class Model(nn.Module):
  def __init__(self, config: ModelArgs):
    super().__init__()
    self.args = config
    self.model_type = config.model_type
    self.model = DeepseekV2Model(config)
    if self.args.shard.is_last_layer():
      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

  def __call__(
    self,
    inputs: mx.array,
    cache: Optional[KVCache] = None,
  ):
    out = self.model(inputs, cache)
    if self.args.shard.is_last_layer():
      return self.lm_head(out)
    return out
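
  # sanitize() filters a full checkpoint down to the weights this shard actually needs
  # (its layer range, plus embed_tokens on the first shard and norm/lm_head on the last),
  # then stacks the per-expert MoE projection weights into the fused switch_mlp layout
  # used by mlx_lm's DeepSeek-V2 implementation.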
  def sanitize(self, weights):
    shard_state_dict = {}

    for key, value in weights.items():
      if key.startswith('model.layers.'):
        layer_num = int(key.split('.')[2])
        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
          shard_state_dict[key] = value
      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
        shard_state_dict[key] = value
      elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
        shard_state_dict[key] = value

    for l in range(self.args.num_hidden_layers):
      prefix = f"model.layers.{l}"
      for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
        for k in ["weight", "scales", "biases"]:
          if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
            to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
            shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)

    return shard_state_dict
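
  # head_dim is a (key_dim, value_dim) pair rather than a single int because DeepSeek-V2's
  # multi-head latent attention uses qk_nope_head_dim + qk_rope_head_dim for queries/keys
  # and a separate v_head_dim for values.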
  @property
  def layers(self):
    return self.model.layers

  @property
  def head_dim(self):
    return (
      self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
      self.args.v_head_dim,
    )

  @property
  def n_kv_heads(self):
    return self.args.num_key_value_heads
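

# Illustrative usage sketch (assumptions, not part of the original module): a node serving
# the upper half of a hypothetical 27-layer checkpoint might build its shard roughly like
# this, with `config_json` and `raw_weights` loaded from the downloaded checkpoint:
#
#   args = ModelArgs.from_dict({**config_json, "shard": Shard("deepseek-v2", 14, 26, 27)})
#   model = Model(args)
#   weights = model.sanitize(raw_weights)  # keep only this shard's tensors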