# llama.py

from typing import Tuple, Union, Optional, Dict, Any
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv
from exo.inference.shard import Shard

# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
  freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
  # TODO: move dtype outside this
  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim//2, 2)
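
# Shape sketch (illustrative note, not from the upstream file): with dim=8 and end=4 this
# returns a (1, 4, 1, 4, 2) tensor, laid out as (broadcast batch, position, broadcast head,
# dim//2 rotation pairs, [cos, sin]) -- exactly the layout apply_rotary_emb consumes below.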

# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
def complex_mult(A, c, d):
  a, b = A[..., 0:1], A[..., 1:2]
  ro = a*c - b*d
  co = a*d + b*c
  return ro.cat(co, dim=-1)

def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
  assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
  xq = xq.reshape(*xq.shape[0:-1], -1, 2)
  xk = xk.reshape(*xk.shape[0:-1], -1, 2)
  assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
  c, d = freqs_cis[..., 0:1], freqs_cis[..., 1:2]
  xq_out = complex_mult(xq, c, d)
  xk_out = complex_mult(xk, c, d)
  return xq_out.flatten(3), xk_out.flatten(3)

def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
  bs, seqlen, n_kv_heads, head_dim = x.shape
  if n_rep == 1: return x
  # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
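
# Example (illustrative): with n_kv_heads=2 and n_rep=4 the expanded head order is
# [kv0, kv0, kv0, kv0, kv1, kv1, kv1, kv1], so query head j attends to kv head j // n_rep
# (grouped-query attention), whereas x.repeat((1, 1, n_rep, 1)) would interleave the heads.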

class Attention:
  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
    self.head_dim = dim // n_heads
    self.n_rep = self.n_heads // self.n_kv_heads
    self.max_context = max_context

    self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)

    xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
    bsz, seqlen, _, _ = xq.shape

    # create kv cache
    if not hasattr(self, "cache_kv"):
      self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
      if isinstance(x.device, tuple):
        # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
        self.cache_kv.shard_((x.device), axis=None).realize()

    # update the cache
    assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
    self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()

    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv

    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
    attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
    attn = attn.reshape(bsz, seqlen, -1)
    return self.wo(attn)
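
# Shape walkthrough (sketch, assuming a single device, bsz=1, seqlen=S, dim=D):
#   wq maps (1, S, D) -> (1, S, n_heads*head_dim); wk/wv map to (1, S, n_kv_heads*head_dim);
#   cache_kv is (2, 1, max_context, n_kv_heads, head_dim) with keys in row 0 and values in row 1;
#   after repeat_kv and transpose, attention runs over (1, n_heads, S, head_dim) tensors.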

class FeedForward:
  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
    self.w1 = linear(dim, hidden_dim, bias=False)
    self.w2 = linear(hidden_dim, dim, bias=False)
    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit

  def __call__(self, x: Tensor) -> Tensor:
    return self.w2(self.w1(x).silu() * self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]

class TransformerBlock:
  def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, linear=nn.Linear, feed_forward=FeedForward):
    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
    self.feed_forward = feed_forward(dim, hidden_dim, linear)
    self.attention_norm = nn.RMSNorm(dim, norm_eps)
    self.ffn_norm = nn.RMSNorm(dim, norm_eps)

  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
    return (h + self.feed_forward(self.ffn_norm(h))).contiguous()

# standard openai sampling
def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
  assert logits.ndim == 1, "only works on 1d tensors"
  assert 0 <= p <= 1, "p must be between 0 and 1"
  assert 0 <= k <= logits.numel(), "k must be between 0 and numel"

  # if temperature is very low just use argmax
  if temp < 1e-6: return logits.argmax()

  # alpha sampling
  if af or ap:
    if not hasattr(sample, "alpha_counter"):
      setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
    logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0) * ap)

  # replace NaNs with -inf
  logits = (logits != logits).where(-float("inf"), logits)

  # softmax
  t = (logits / temp).softmax()

  counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
  # top k
  if k:
    output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
    for i in range(k):
      t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
      t = (counter == t_argmax).where(0, t)

    # approximate top p
    # because we are already limited to top k elements we can do top p "without sorting"
    output_cumsum = output[::-1]._cumsum()[::-1] + t.sum()
    output = (output_cumsum >= (1 - p)) * output
    output_indices = (output_cumsum >= (1 - p)) * output_indices

    # sample
    output_idx = output.multinomial()
    output_token = output_indices[output_idx]
  else:
    output_token = t.multinomial()

  # increase alpha counter
  if af or ap:
    sample.alpha_counter = (counter == output_token).where(sample.alpha_counter + 1, sample.alpha_counter)

  return output_token
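
# Usage sketch (illustrative): sample(logits.flatten(), temp=0.7, k=40, p=0.9, af=0.0, ap=0.0)
# yields a Tensor holding the chosen token id. temp below 1e-6 degenerates to argmax, and
# nonzero af/ap apply the frequency/presence ("alpha") penalties tracked in sample.alpha_counter.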

class Transformer:
  def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_layers: int, norm_eps: float, vocab_size, shard: Shard, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(shard.end_layer - shard.start_layer + 1)]
    self.norm = nn.RMSNorm(dim, norm_eps)
    self.tok_embeddings = nn.Embedding(vocab_size, dim)
    self.output = nn.Linear(dim, vocab_size, bias=False)
    self.max_context = max_context
    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
    self.forward_jit = TinyJit(self.forward) if jit else None
    self.shard = shard

  def forward(self, h: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
    seqlen = h.shape[1]
    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen), None, None, None))

    if self.shard.is_first_layer():
      h = self.tok_embeddings(h)
    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None

    for i, layer in enumerate(self.layers):
      h = layer(h, start_pos, freqs_cis, mask)
      # if i == 0 or i == len(self.layers) - 1:
      #   print(f"layer {i}: {str(h.numpy())[:60]}")

    if self.shard.is_last_layer():
      logits = self.output(self.norm(h)).float()[:, -1, :]
      return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
    else:
      return h.realize()

  def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
    # TODO: better way to handle the first call vs. the rest?
    # if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
    #   return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
    return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)

  def reset(self):
    for layer in self.layers:
      if hasattr(layer.attention, "cache_kv"):
        layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
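
# Rough decode-loop sketch (illustrative only; in exo the tinygrad inference engine drives this,
# and weight loading / tokenization are omitted -- every name below is a placeholder):
#   model = Transformer(dim, hidden_dim, n_heads, n_layers, norm_eps, vocab_size, shard, n_kv_heads=n_kv_heads)
#   tok, start_pos = Tensor([[bos_id]]), 0
#   while not done:                    # only when this shard ends at the last layer;
#     tok = model(tok, start_pos)      # otherwise forward() returns hidden states instead
#     tok, start_pos = tok.reshape(1, 1), start_pos + 1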

# *** helpers ***

def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard):
  def permute(v: Tensor, n_heads: int):
    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])

  keymap = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
  }
  sd = {}
  for k, v in weights.items():
    if ".rotary_emb." in k: continue
    v = v.to(Device.DEFAULT)
    if "model.layers" in k:
      layer_num = int(k.split('.')[2])
      if shard.start_layer <= layer_num <= shard.end_layer:
        k = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
      else:
        continue

    if "q_proj" in k:
      v = permute(v, n_heads)
    elif "k_proj" in k:
      v = permute(v, n_kv_heads)
    sd[keymap[k]] = v
  return sd
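
# Mapping example (illustrative): with shard.start_layer == 4, the HF weight
# "model.layers.5.self_attn.q_proj.weight" is renumbered to layer 1 of this shard, permuted
# to match the interleaved RoPE layout used here, and stored as "layers.1.attention.wq.weight";
# layers outside [shard.start_layer, shard.end_layer] are skipped entirely.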

def fix_bf16(weights: Dict[Any, Tensor]):
  if getenv("SUPPORT_BF16", 1):
    # TODO: without casting to float16, 70B llama OOM on tinybox.
    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
  # TODO: check if device supports bf16
  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}