phi3.py

from dataclasses import dataclass, field

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.phi3 import TransformerBlock, ModelArgs

from ...shard import Shard
from .base import IdentityBlock


@dataclass
class ModelArgs(ModelArgs):
  # Extends the upstream phi3 ModelArgs with shard metadata describing which
  # layer range this node owns.
  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))

  def __post_init__(self):
    super().__post_init__()

    if isinstance(self.shard, Shard):
      return
    if not isinstance(self.shard, dict):
      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")

    # A shard given as a plain dict (e.g. straight from a config) is
    # converted into a Shard instance.
    self.shard = Shard(**self.shard)
class Phi3Model(nn.Module):
  def __init__(self, args: ModelArgs):
    super().__init__()
    self.args = args
    self.vocab_size = args.vocab_size
    self.num_hidden_layers = args.num_hidden_layers
    assert self.vocab_size > 0

    # Only the first shard holds the token embedding.
    if self.args.shard.is_first_layer():
      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)

    # Layers inside this shard's range are real transformer blocks; all
    # other positions are pass-through IdentityBlocks so layer indices
    # still line up with the full model.
    self.layers = []
    for i in range(self.num_hidden_layers):
      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
        self.layers.append(TransformerBlock(args=args))
      else:
        self.layers.append(IdentityBlock())

    # Only the last shard applies the final RMSNorm.
    if self.args.shard.is_last_layer():
      self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

  def __call__(
    self,
    inputs: mx.array,
    cache=None,
  ):
    # The first shard embeds token ids; later shards receive the previous
    # shard's hidden states directly.
    if self.args.shard.is_first_layer():
      h = self.embed_tokens(inputs)
    else:
      h = inputs

    mask = None
    if h.shape[1] > 1:
      mask = create_attention_mask(h, cache)

    if cache is None:
      cache = [None] * len(self.layers)

    for layer, c in zip(self.layers, cache):
      h = layer(h, mask, c)

    if self.args.shard.is_last_layer():
      h = self.norm(h)
    return h
class Model(nn.Module):
  def __init__(self, args: ModelArgs):
    super().__init__()
    self.args = args
    self.model_type = args.model_type
    self.model = Phi3Model(args)
    # Only the last shard projects hidden states back to vocabulary logits.
    if self.args.shard.is_last_layer():
      self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

  def __call__(
    self,
    inputs: mx.array,
    cache=None,
  ):
    out = self.model(inputs, cache)
    if self.args.shard.is_last_layer():
      out = self.lm_head(out)
    return out

  def sanitize(self, weights):
    # Keep only the weights this shard actually owns, so each node loads a
    # fraction of the checkpoint.
    shard_state_dict = {}

    for key, value in weights.items():
      if "self_attn.rope.inv_freq" in key:
        continue
      if key.startswith('model.layers.'):
        layer_num = int(key.split('.')[2])
        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
          shard_state_dict[key] = value
      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
        shard_state_dict[key] = value
      elif self.args.shard.is_last_layer() and (key.startswith('lm_head') or key.startswith('model.norm')):
        shard_state_dict[key] = value

    return shard_state_dict

  @property
  def layers(self):
    return self.model.layers

  @property
  def head_dim(self):
    return self.args.hidden_size // self.args.num_attention_heads

  @property
  def n_kv_heads(self):
    return self.args.num_key_value_heads
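

# Hedged usage sketch (not part of the upstream file): illustrates how the
# Shard bounds above decide which parts of the network a node owns. The model
# id "phi-3-mini" and the 32-layer split are made-up values for illustration;
# constructing a full ModelArgs is omitted here because its remaining fields
# come from the upstream mlx_lm phi3 config.
if __name__ == "__main__":
  # A hypothetical shard covering layers 0-15 of a 32-layer model.
  shard = Shard("phi-3-mini", 0, 15, 32)
  print(shard.is_first_layer())  # expected True: this node holds embed_tokens
  print(shard.is_last_layer())   # expected False: norm and lm_head live on the last shard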