llama.py

from dataclasses import dataclass, field

import mlx.core as mx
import mlx.nn as nn

from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.llama import TransformerBlock, ModelArgs

from ...shard import Shard
from .base import IdentityBlock


@dataclass
class ModelArgs(ModelArgs):
  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))

  def __post_init__(self):
    super().__post_init__()  # Ensure parent initializations are respected

    if isinstance(self.shard, Shard):
      return
    if not isinstance(self.shard, dict):
      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")

    self.shard = Shard(**self.shard)
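
# The shard can also be supplied as a plain dict (e.g. parsed from a JSON
# config); __post_init__ above converts it into a Shard. Illustrative sketch,
# assuming Shard's fields are model_id, start_layer, end_layer and n_layers:
#   ModelArgs(..., shard={"model_id": "llama-3-8b", "start_layer": 0, "end_layer": 15, "n_layers": 32})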


class LlamaModel(nn.Module):
  def __init__(self, args: ModelArgs):
    super().__init__()
    self.args = args
    self.vocab_size = args.vocab_size
    self.num_hidden_layers = args.num_hidden_layers
    assert self.vocab_size > 0
    # Token embeddings are only needed on the first shard (to embed the input
    # ids) and, when embeddings are tied to the output projection, on the last.
    if args.shard.is_first_layer() or (args.shard.is_last_layer() and args.tie_word_embeddings):
      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
    # Layers inside this shard's [start_layer, end_layer] range are real
    # transformer blocks; everything else is a pass-through IdentityBlock.
    self.layers = []
    for i in range(self.num_hidden_layers):
      if args.shard.start_layer <= i <= args.shard.end_layer:
        self.layers.append(TransformerBlock(args=args))
      else:
        self.layers.append(IdentityBlock())
    # The final RMSNorm only lives on the last shard.
    if args.shard.is_last_layer():
      self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

  def __call__(
    self,
    inputs: mx.array,
    cache=None,
  ):
    # The first shard receives token ids and embeds them; later shards receive
    # the hidden states produced by the previous shard.
    if self.args.shard.is_first_layer():
      h = self.embed_tokens(inputs)
    else:
      h = inputs

    mask = None
    if h.ndim > 1 and h.shape[1] > 1:
      mask = create_attention_mask(h, cache)

    if cache is None:
      cache = [None]*len(self.layers)

    for layer, c in zip(self.layers, cache):
      h = layer(h, mask, cache=c)

    if self.args.shard.is_last_layer():
      h = self.norm(h)
    return h


class Model(nn.Module):
  def __init__(self, args: ModelArgs):
    super().__init__()
    self.args = args
    self.model_type = args.model_type
    self.model = LlamaModel(args)
    # The output projection is only materialized on the last shard, and only
    # when the embedding weights are not reused for it.
    if args.shard.is_last_layer():
      if not args.tie_word_embeddings:
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

  def __call__(
    self,
    inputs: mx.array,
    cache=None,
  ):
    out = self.model(inputs, cache)
    if self.args.shard.is_last_layer():
      if self.args.tie_word_embeddings:
        out = self.model.embed_tokens.as_linear(out)
      else:
        out = self.lm_head(out)
    return out

  def sanitize(self, weights):
    # Keep only the checkpoint weights this shard actually owns: its slice of
    # model.layers plus, depending on shard position and weight tying, the
    # embeddings, lm_head and final norm.
    shard_state_dict = {}

    for key, value in weights.items():
      if "self_attn.rotary_emb.inv_freq" in key:
        continue
      if key.startswith('model.layers.'):
        layer_num = int(key.split('.')[2])
        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
          shard_state_dict[key] = value
      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
        shard_state_dict[key] = value
      elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'):
        shard_state_dict[key] = value
      elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'):
        shard_state_dict[key] = value
      elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
        shard_state_dict[key] = value

    return shard_state_dict

  @property
  def layers(self):
    return self.model.layers

  @property
  def head_dim(self):
    return (self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads)

  @property
  def n_kv_heads(self):
    return self.args.num_key_value_heads
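

if __name__ == "__main__":
  # Minimal smoke-test sketch: build a tiny, randomly initialized model whose
  # shard covers both layers and run a single forward pass. The ModelArgs
  # fields below mirror mlx_lm's llama ModelArgs, and Shard's positional order
  # (model_id, start_layer, end_layer, n_layers) is assumed; both may differ
  # between versions. Because this module uses relative imports, run it with
  # `python -m <package>.llama` from the package root rather than directly.
  args = ModelArgs(
    model_type="llama",
    hidden_size=64,
    intermediate_size=128,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_hidden_layers=2,
    rms_norm_eps=1e-5,
    vocab_size=128,
    shard=Shard("dummy", 0, 1, 2),
  )
  model = Model(args)
  tokens = mx.array([[1, 2, 3, 4]])
  logits = model(tokens)
  print(logits.shape)  # expected: (1, 4, vocab_size)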