# Copyright © 2024 Apple Inc.

import math
import inspect
from dataclasses import dataclass, field
from typing import Optional, Dict, Union

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import BaseModelArgs, KVCache
from exo.inference.shard import Shard
import numpy as np


@dataclass
class VisionConfig:
  model_type: str
  num_hidden_layers: int = 24
  hidden_size: int = 1024
  intermediate_size: int = 4096
  num_attention_heads: int = 16
  image_size: int = 336
  patch_size: int = 14
  projection_dim: int = 768
  vocab_size: int = 32000
  num_channels: int = 3
  layer_norm_eps: float = 1e-5

  @classmethod
  def from_dict(cls, params):
    return cls(
      **{
        k: v
        for k, v in params.items()
        if k in inspect.signature(cls).parameters
      }
    )


class VisionAttention(nn.Module):
  def __init__(
    self,
    dims: int,
    num_heads: int,
    query_input_dims: Optional[int] = None,
    key_input_dims: Optional[int] = None,
    value_input_dims: Optional[int] = None,
    value_dims: Optional[int] = None,
    value_output_dims: Optional[int] = None,
    bias: bool = False,
  ):
    super().__init__()

    if (dims % num_heads) != 0:
      raise ValueError(
        "The input feature dimensions should be divisible by the "
        f"number of heads ({dims} % {num_heads}) != 0"
      )

    query_input_dims = query_input_dims or dims
    key_input_dims = key_input_dims or dims
    value_input_dims = value_input_dims or key_input_dims
    value_dims = value_dims or dims
    value_output_dims = value_output_dims or dims

    self.num_heads = num_heads
    self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
    self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
    self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
    self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)

  def __call__(self, queries, keys, values, mask=None):
    queries = self.q_proj(queries)
    keys = self.k_proj(keys)
    values = self.v_proj(values)

    num_heads = self.num_heads
    B, L, D = queries.shape
    _, S, _ = keys.shape
    queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
    keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
    values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

    scale = math.sqrt(1 / queries.shape[-1])
    scores = (queries * scale) @ keys
    if mask is not None:
      scores = scores + mask.astype(scores.dtype)
    scores = mx.softmax(scores, axis=-1)
    values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

    return self.out_proj(values_hat)


class VisionMLP(nn.Module):
  def __init__(self, config: VisionConfig):
    super().__init__()
    self.activation_fn = nn.GELU(approx="fast")
    self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
    self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

  def __call__(self, x: mx.array) -> mx.array:
    x = self.activation_fn(self.fc1(x))
    x = self.fc2(x)
    return x


class VisionEncoderLayer(nn.Module):
  def __init__(self, config: VisionConfig):
    super().__init__()
    self.embed_dim = config.hidden_size
    self.self_attn = VisionAttention(
      config.hidden_size, config.num_attention_heads, bias=True
    )
    self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
    self.mlp = VisionMLP(config)
    self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

  def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
    y = self.layer_norm1(x)
    y = self.self_attn(y, y, y, mask)
    x = x + y
    y = self.layer_norm2(x)
    y = self.mlp(y)
    return x + y


class VisionEncoder(nn.Module):
  def __init__(self, config: VisionConfig):
    super().__init__()
    self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]


class VisionEmbeddings(nn.Module):
  def __init__(self, config: VisionConfig):
    super().__init__()
    self.config = config
    self.embed_dim = config.hidden_size
    self.image_size = config.image_size
    self.patch_size = config.patch_size

    self.class_embedding = mx.zeros((config.hidden_size,))

    self.patch_embedding = nn.Conv2d(
      in_channels=config.num_channels,
      out_channels=self.embed_dim,
      kernel_size=self.patch_size,
      stride=self.patch_size,
      bias=False,
    )

    self.num_patches = (self.image_size // self.patch_size) ** 2
    self.num_positions = self.num_patches + 1
    self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

  def __call__(self, x: mx.array) -> mx.array:
    batch_size = x.shape[0]
    patch_embeddings = self.patch_embedding(x)
    patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
    embed_dim = patch_embeddings.shape[-1]
    cls_embeddings = mx.broadcast_to(
      self.class_embedding, (batch_size, 1, embed_dim)
    )
    embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
    embeddings += self.position_embedding.weight
    return embeddings


class ClipVisionModel(nn.Module):
  def __init__(self, config: VisionConfig):
    super().__init__()
    self.embeddings = VisionEmbeddings(config)
    # "pre_layrnorm" (sic) intentionally matches the misspelled parameter name
    # in the upstream CLIP checkpoints so weights load without remapping.
    self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
    self.encoder = VisionEncoder(config)
    self.post_layernorm = nn.LayerNorm(config.hidden_size)

  def __call__(
    self,
    x: mx.array,
    output_hidden_states: Optional[bool] = None,
  ) -> mx.array:
    x = self.embeddings(x)
    x = self.pre_layrnorm(x)

    encoder_states = (x,) if output_hidden_states else None

    for l in self.encoder.layers:
      x = l(x, mask=None)
      if output_hidden_states:
        encoder_states = encoder_states + (x,)

    pooler_output = self.post_layernorm(x[:, 0, :])
    return pooler_output, x, encoder_states


class VisionModel(nn.Module):
  def __init__(self, config: VisionConfig):
    super().__init__()
    self.model_type = config.model_type
    if self.model_type != "clip_vision_model":
      raise ValueError(f"Unsupported model type: {self.model_type}")

    self.vision_model = ClipVisionModel(config)

  def __call__(
    self, x: mx.array, output_hidden_states: Optional[bool] = None
  ) -> mx.array:
    return self.vision_model(x, output_hidden_states)

  @staticmethod
  def sanitize(weights):
    sanitized_weights = {}
    for k, v in weights.items():
      if "position_ids" in k:
        # Remove unused position_ids
        continue
      elif "patch_embedding.weight" in k:
        # PyTorch conv2d weight tensors have shape:
        #   [out_channels, in_channels, kH, kW]
        # MLX conv2d expects the weight to be of shape:
        #   [out_channels, kH, kW, in_channels]
        sanitized_weights[k] = v.transpose(0, 2, 3, 1)
      else:
        sanitized_weights[k] = v
    return sanitized_weights


@dataclass
class TextConfig:
  model_type: str
  hidden_size: int = 4096
  num_hidden_layers: int = 32
  intermediate_size: int = 11008
  num_attention_heads: int = 32
  head_dim: int = None
  rms_norm_eps: float = 1e-6
  vocab_size: int = 32000
  num_key_value_heads: int = None
  rope_theta: float = 10000
  rope_traditional: bool = False
  rope_scaling: Optional[Dict[str, Union[float, str]]] = None

  @classmethod
  def from_dict(cls, params):
    return cls(
      **{
        k: v
        for k, v in params.items()
        if k in inspect.signature(cls).parameters
      }
    )

  def __post_init__(self):
    if self.num_key_value_heads is None:
      self.num_key_value_heads = self.num_attention_heads

    if self.head_dim is None:
      self.head_dim = self.hidden_size // self.num_attention_heads

    if self.model_type is None:
      self.model_type = "llama"

    if self.rope_scaling:
      required_keys = {"factor", "type"}
      if not all(key in self.rope_scaling for key in required_keys):
        raise ValueError(f"rope_scaling must contain keys {required_keys}")

      if self.rope_scaling["type"] != "linear":
        raise ValueError("rope_scaling 'type' currently only supports 'linear'")


class TextAttention(nn.Module):
  def __init__(self, config: TextConfig):
    super().__init__()

    dim = config.hidden_size
    self.n_heads = n_heads = config.num_attention_heads
    self.n_kv_heads = n_kv_heads = config.num_key_value_heads
    self.repeats = n_heads // n_kv_heads

    head_dim = config.hidden_size // n_heads
    self.scale = head_dim ** -0.5

    self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
    self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
    self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
    self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

    rope_scale = (
      1 / config.rope_scaling["factor"]
      if config.rope_scaling is not None
      and config.rope_scaling["type"] == "linear"
      else 1
    )
    self.rope = nn.RoPE(
      head_dim,
      traditional=config.rope_traditional,
      base=config.rope_theta,
      scale=rope_scale,
    )

  def __call__(
    self,
    x: mx.array,
    mask: Optional[mx.array] = None,
    cache: Optional[KVCache] = None,
  ) -> mx.array:
    B, L, D = x.shape

    queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

    # Prepare the queries, keys and values for the attention computation
    queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
    keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
    values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

    if cache is not None:
      queries = self.rope(queries, offset=cache.offset)
      keys = self.rope(keys, offset=cache.offset)
      keys, values = cache.update_and_fetch(keys, values)
    else:
      queries = self.rope(queries)
      keys = self.rope(keys)

    output = mx.fast.scaled_dot_product_attention(
      queries, keys, values, scale=self.scale, mask=mask
    )
    output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
    return self.o_proj(output)


class TextMLP(nn.Module):
  def __init__(self, dim, hidden_dim):
    super().__init__()
    self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
    self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
    self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

  def __call__(self, x) -> mx.array:
    return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
  def __init__(self, config: TextConfig):
    super().__init__()
    self.num_attention_heads = config.num_attention_heads
    self.hidden_size = config.hidden_size
    self.self_attn = TextAttention(config)
    self.mlp = TextMLP(config.hidden_size, config.intermediate_size)
    self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = nn.RMSNorm(
      config.hidden_size, eps=config.rms_norm_eps
    )
    self.config = config

  def __call__(
    self,
    x: mx.array,
    mask: Optional[mx.array] = None,
    cache: Optional[KVCache] = None,
  ) -> mx.array:
    r = self.self_attn(self.input_layernorm(x), mask, cache)
    h = x + r
    r = self.mlp(self.post_attention_layernorm(h))
    out = h + r
    return out


class Llama(nn.Module):
  def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
    super().__init__()
    self.config = config
    self.is_first_layer = is_first_layer
    self.is_last_layer = is_last_layer
    self.vocab_size = config.vocab_size
    self.model_type = config.model_type
    self.num_hidden_layers = config.num_hidden_layers
    self.num_key_value_heads = config.num_key_value_heads
    self.head_dim = config.head_dim
    assert self.vocab_size > 0
    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
    self.layers = [
      TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
    ]
    self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

  def __call__(
    self,
    inputs: mx.array,
    cache=None,
    inputs_embeds=None,
  ):
    # for passing merged input embeddings
    if inputs_embeds is None:
      if self.is_first_layer:
        h = self.embed_tokens(inputs)
      else:
        h = inputs
    else:
      h = inputs_embeds

    mask = None
    if h.shape[1] > 1:
      mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
      mask = mask.astype(h.dtype)

    if cache is None:
      cache = [None] * len(self.layers)

    for layer, c in zip(self.layers, cache):
      h = layer(h, mask, c)

    if self.is_last_layer:
      h = self.norm(h)
    return h


class LanguageModel(nn.Module):
  def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
    super().__init__()
    self.model_type = config.model_type
    if self.model_type != "llama":
      raise ValueError(
        f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
      )
    self.is_last_layer = is_last_layer
    self.model = Llama(config, is_first_layer, is_last_layer)
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

  def __call__(
    self,
    inputs: mx.array,
    cache=None,
    inputs_embeds=None,
  ):
    out = self.model(inputs, cache, inputs_embeds)
    if self.is_last_layer:
      out = self.lm_head(out)
    return out

  @staticmethod
  def sanitize(weights):
    # Remove unused precomputed rotary freqs
    return {
      k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
    }


@dataclass
class LlaVAConfig(BaseModelArgs):
  text_config: TextConfig
  vision_config: VisionConfig = None
  model_type: str = "llava"
  ignore_index: int = -100
  image_token_index: int = 32000
  vision_feature_select_strategy: str = "default"
  vision_feature_layer: int = -2
  vocab_size: int = 32000

  @classmethod
  def from_dict(cls, params):
    updated_params = {}
    class_params = inspect.signature(cls).parameters
    for k, v in params.items():
      if k in class_params:
        if k in ["text_config", "vision_config"]:
          v = class_params[k].annotation.from_dict(v)
        updated_params.update({k: v})
    return cls(**updated_params)


@dataclass
class ModelArgs(LlaVAConfig):
  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))

  def __post_init__(self):
    if isinstance(self.shard, dict):
      self.shard = Shard(**self.shard)

    if not isinstance(self.shard, Shard):
      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")

    if not self.shard.is_first_layer():
      self.vision_config = None

    self.text_config.num_hidden_layers = self.shard.get_layer_count()
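
# Sharding note (illustrative summary of how the classes below use ModelArgs):
# a node owning layers [start_layer, end_layer] of an n_layers model builds
# only shard.get_layer_count() TransformerBlocks; non-first shards drop
# vision_config, so no vision tower or projector is built and raw hidden
# states are passed in instead of token ids; only the last shard applies the
# final RMSNorm and lm_head.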


class LlavaMultiModalProjector(nn.Module):
  def __init__(self, config: LlaVAConfig):
    super().__init__()
    self.linear_1 = nn.Linear(
      config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
    )
    self.gelu = nn.GELU()
    self.linear_2 = nn.Linear(
      config.text_config.hidden_size, config.text_config.hidden_size, bias=True
    )

  def __call__(self, x: mx.array) -> mx.array:
    x = self.linear_1(x)
    x = self.gelu(x)
    x = self.linear_2(x)
    return x


class Model(nn.Module):
  def __init__(self, config: ModelArgs):
    super().__init__()
    self.config = config
    self.model_type = config.model_type
    if config.vision_config:
      self.vision_tower = VisionModel(config.vision_config)
      self.multi_modal_projector = LlavaMultiModalProjector(config)
      self.vision_feature_layer = config.vision_feature_layer
      self.vision_feature_select_strategy = config.vision_feature_select_strategy
    self.language_model = LanguageModel(config.text_config, config.shard.is_first_layer(), config.shard.is_last_layer())

  def get_input_embeddings(
    self,
    input_ids: Optional[mx.array] = None,
    pixel_values: Optional[mx.array] = None,
  ):
    if pixel_values is None:
      return self.language_model(input_ids)

    # Get the input embeddings from the language model
    inputs_embeds = self.language_model.model.embed_tokens(input_ids)

    # Get the output hidden states from the vision model
    *_, hidden_states = self.vision_tower(
      pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
    )

    # Select the hidden states from the desired layer
    selected_image_feature = hidden_states[self.vision_feature_layer]

    if self.vision_feature_select_strategy == "default":
      selected_image_feature = selected_image_feature[:, 1:]
    elif self.vision_feature_select_strategy == "full":
      selected_image_feature = selected_image_feature
    else:
      raise ValueError(
        "Unexpected feature selection strategy: "
        f"{self.vision_feature_select_strategy}"
      )

    # Pass image features through the multi-modal projector
    image_features = self.multi_modal_projector(selected_image_feature)

    # Insert special image tokens in the input_ids
    final_inputs_embeds = self._merge_input_ids_with_image_features(
      image_features, inputs_embeds, input_ids
    )
    return final_inputs_embeds

  def _merge_input_ids_with_image_features(
    self, image_features, inputs_embeds, input_ids
  ):
    image_token_index = self.config.image_token_index
    num_images, num_image_patches, embed_dim = image_features.shape

    # Positions of <image> tokens in input_ids, assuming batch size is 1
    image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()

    if len(image_positions) != num_images:
      raise ValueError(
        f"The number of image tokens ({len(image_positions)}) does not "
        f"match the number of image inputs ({num_images})."
      )

    text_segments = []
    start_idx = 0

    for position in image_positions:
      text_segments.append(inputs_embeds[:, start_idx:position])
      start_idx = position + 1

    image_embeddings = mx.split(image_features, image_features.shape[0])
    final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
    final_embeddings += [inputs_embeds[:, start_idx:]]

    # Create a final embedding of shape
    # (1, num_image_patches*num_images + sequence_len, embed_dim)
    return mx.concatenate(final_embeddings, axis=1)

  def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
    input_embeddings = None
    if pixel_values is not None:
      input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
    logits = self.language_model(
      input_ids, cache=cache, inputs_embeds=input_embeddings
    )
    return logits

  def sanitize(self, weights):
    if self.config.vision_config:
      weights = self.vision_tower.sanitize(weights)
    weights = self.language_model.sanitize(weights)
    return weights

  @property
  def layers(self):
    return self.language_model.model.layers

  @property
  def head_dim(self):
    # head_dim is always populated by TextConfig.__post_init__; fall back to
    # the config values if it were ever left unset.
    return (
      self.language_model.model.head_dim
      or self.language_model.model.config.hidden_size // self.language_model.model.config.num_attention_heads
    )

  @property
  def n_kv_heads(self):
    return self.language_model.model.num_key_value_heads
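

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the upstream module). It builds a
# tiny, randomly initialised single-node shard and runs a text-only forward
# pass. The hidden sizes, shard arguments, and token ids below are
# illustrative assumptions, not real LLaVA hyperparameters; in practice exo
# builds ModelArgs from the model's config.json and loads pretrained weights
# through sanitize().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
  text_config = TextConfig(
    model_type="llama",
    hidden_size=64,
    num_hidden_layers=2,
    intermediate_size=128,
    num_attention_heads=4,
    vocab_size=256,
  )
  vision_config = VisionConfig(
    model_type="clip_vision_model",
    num_hidden_layers=2,
    hidden_size=32,
    intermediate_size=64,
    num_attention_heads=4,
    image_size=28,
    patch_size=14,
  )
  # One shard covering both text layers: model_id, start_layer, end_layer, n_layers.
  args = ModelArgs(
    text_config=text_config,
    vision_config=vision_config,
    shard=Shard("llava-example", 0, 1, 2),
  )
  model = Model(args)

  # Text-only path: pixel_values is omitted, so the merged-embedding branch is skipped.
  tokens = mx.array([[1, 2, 3, 4]])
  logits = model(tokens)
  print(logits.shape)  # expected: (1, 4, 256)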