from typing import Tuple, Union, Optional, Dict, Any
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv

# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
  freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
  freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
  # TODO: move dtype outside this
  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
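
# Shape note: the returned tensor has shape (1, end, 1, dim//2, 2), where entry [0, t, 0, i, :]
# holds (cos(t*theta**(-2*i/dim)), sin(t*theta**(-2*i/dim))), i.e. the rotation that
# apply_rotary_emb applies to the i-th (real, imag) pair of every head at position t.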

# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
def complex_mult(A, c, d):
  a, b = A[..., 0:1], A[..., 1:2]
  ro = a*c - b*d
  co = a*d + b*c
  return ro.cat(co, dim=-1)
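
# Illustrative sanity check (arbitrary values): rotating the pair (1, 0) by 90 degrees,
# i.e. c = cos(pi/2) = 0 and d = sin(pi/2) = 1, should land on (0, 1):
#   complex_mult(Tensor([[1.0, 0.0]]), 0.0, 1.0)  # -> [[0.0, 1.0]]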

def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
  assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
  xq = xq.reshape(*xq.shape[0:-1], -1, 2)
  xk = xk.reshape(*xk.shape[0:-1], -1, 2)
  assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
  c, d = freqs_cis[..., 0:1], freqs_cis[..., 1:2]
  xq_out = complex_mult(xq, c, d)
  xk_out = complex_mult(xk, c, d)
  return xq_out.flatten(3), xk_out.flatten(3)

def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
  bs, seqlen, n_kv_heads, head_dim = x.shape
  if n_rep == 1: return x
  # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
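
# Illustrative layout (arbitrary sizes): for grouped-query attention with n_kv_heads=2 and n_rep=4,
# the output head axis reads [kv0, kv0, kv0, kv0, kv1, kv1, kv1, kv1], so query heads 0-3 share
# kv head 0 and query heads 4-7 share kv head 1. x.repeat((1, 1, n_rep, 1)) would interleave them
# as [kv0, kv1, kv0, kv1, ...] instead, which is what the NOTE above is warning about.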

class Attention:
  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies GQA (MQA when n_kv_heads == 1) [arxiv/2307.09288, A.2.1]
    self.head_dim = dim // n_heads
    self.n_rep = self.n_heads // self.n_kv_heads
    self.max_context = max_context

    self.wq = linear(dim, self.n_heads*self.head_dim, bias=False)
    self.wk = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
    self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
    self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)

  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
    if getenv("WQKV"):
      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
      xqkv = x @ self.wqkv.T
      xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
    else:
      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)

    xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
    bsz, seqlen, _, _ = xq.shape

    # create kv cache
    if not hasattr(self, "cache_kv"):
      self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
      if isinstance(x.device, tuple):
        # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
        self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()

    # update the cache
    assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()

    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv

    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
    attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
    attn = attn.reshape(bsz, seqlen, -1)
    return self.wo(attn)
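
# Usage sketch: Attention is stateful. A first call with start_pos=0 and seqlen > 1 prefills the
# cache for the whole prompt; later calls typically pass a single new token (seqlen == 1) with
# start_pos set to the number of positions already cached, so keys/values are read back from
# cache_kv[0] / cache_kv[1] up to start_pos + seqlen.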

class FeedForward:
  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
    self.w1 = linear(dim, hidden_dim, bias=False)  # gate projection (passed through silu below)
    self.w2 = linear(hidden_dim, dim, bias=False)  # down projection
    self.w3 = linear(dim, hidden_dim, bias=False)  # up projection of the Gated Linear Unit

  def __call__(self, x: Tensor) -> Tensor:
    return self.w2(self.w1(x).silu()*self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
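
# Worked form: the call above computes FFN(x) = w2(silu(w1(x)) * w3(x)), i.e.
# FFN_SwiGLU(x, W, V, W2) = (Swish_1(xW) ⊗ xV)W2 from the cited paper, with w1 ~ W (the gated
# branch), w3 ~ V, w2 ~ W2, and all bias terms dropped.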

class TransformerBlock:
  def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, linear=nn.Linear, feed_forward=FeedForward):
    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
    self.feed_forward = feed_forward(dim, hidden_dim, linear)
    self.attention_norm = nn.RMSNorm(dim, norm_eps)
    self.ffn_norm = nn.RMSNorm(dim, norm_eps)

  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
    return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
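
# Structure note: this is the usual pre-norm residual block of the LLaMA family,
#   h   = x + Attention(RMSNorm(x))
#   out = h + FeedForward(RMSNorm(h))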

# standard openai sampling
def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
  assert logits.ndim == 1, "only works on 1d tensors"
  assert 0 <= p <= 1, "p must be between 0 and 1"
  assert 0 <= k <= logits.numel(), "k must be between 0 and numel"

  # if temperature is very low just use argmax
  if temp < 1e-6: return logits.argmax().reshape(1)

  # alpha sampling
  if af or ap:
    if not hasattr(sample, "alpha_counter"):
      setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
    logits = logits - (sample.alpha_counter*af + (sample.alpha_counter > 0)*ap)

  # replace NaNs with -inf
  logits = (logits != logits).where(-float("inf"), logits)

  # softmax
  t = (logits/temp).softmax()

  counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
  # top k
  if k:
    output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
    for i in range(k):
      t_argmax = (t.numel() - ((t == (t_max := t.max()))*counter2).max() - 1).cast(dtypes.default_int)
      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
      t = (counter == t_argmax).where(0, t)

    # approximate top p
    # because we are already limited to top k elements we can do top p "without sorting"
    output_cumsum = output[::-1]._cumsum()[::-1] + t.sum()
    output = (output_cumsum >= (1 - p))*output
    output_indices = (output_cumsum >= (1 - p))*output_indices

    # sample
    output_idx = output.multinomial()
    output_token = output_indices[output_idx]
  else:
    output_token = t.multinomial()

  # increase alpha counter
  if af or ap:
    sample.alpha_counter = (counter == output_token).where(sample.alpha_counter + 1, sample.alpha_counter)

  return output_token
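
# Example call (illustrative values): greedy decoding is temp=0 (k and p are then ignored); a
# typical stochastic setting might be
#   sample(logits, temp=0.85, k=25, p=0.9, af=0.0, ap=0.0)
# which returns a one-element tensor holding the sampled token index. af/ap play the role of
# frequency/presence penalties via the per-token alpha_counter above.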

from exo.inference.shard import Shard

class Transformer:
  def __init__(
    self,
    dim: int,
    hidden_dim: int,
    n_heads: int,
    n_layers: int,
    norm_eps: float,
    vocab_size,
    shard: Shard = None,
    linear=nn.Linear,
    n_kv_heads=None,
    rope_theta=10000,
    max_context=1024,
    jit=True,
    feed_forward=FeedForward
  ):
    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
    self.norm = nn.RMSNorm(dim, norm_eps)
    self.tok_embeddings = nn.Embedding(vocab_size, dim)
    self.output = nn.Linear(dim, vocab_size, bias=False)
    self.max_context = max_context
    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta).contiguous()
    self.forward_jit = TinyJit(self.forward) if jit else None
    self.shard = shard

  def forward(self, x: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
    seqlen = x.shape[1]
    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None

    if self.shard.is_first_layer():
      h = self.tok_embeddings(x)
    else:
      h = x

    for i in range(self.shard.start_layer, self.shard.end_layer + 1):
      layer = self.layers[i]
      h = layer(h, start_pos, freqs_cis, mask)

    if self.shard.is_last_layer():
      logits = self.output(self.norm(h)).float()[:, -1, :]
      return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
    else:
      return h

  def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
    # TODO: better way to handle the first call vs. the rest?
    if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
    return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
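
# Hypothetical driver sketch (names such as `tokenizer`, `stop_tokens` and the hyperparameters are
# assumptions, not defined in this file):
#   toks = tokenizer.encode(prompt)
#   last = int(model(Tensor([toks]), start_pos=0, temperature=0.7).item())        # prefill the prompt
#   pos = len(toks)
#   while last not in stop_tokens:
#     last = int(model(Tensor([[last]]), start_pos=pos, temperature=0.7).item())  # one token at a time (JIT path)
#     pos += 1
# On shards that do not contain the last layer, the call returns hidden states instead of a token id.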

# *** helpers ***

def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
  def permute(v: Tensor, n_heads: int):
    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])

  keymap = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
       for l in range(len(model.layers))},
    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
       for x in ["q", "k", "v", "o"]
       for l in range(len(model.layers))},
    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
       for l in range(len(model.layers))},
    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
       for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
       for l in range(len(model.layers))},
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
  }
  sd = {}
  for k, v in weights.items():
    if ".rotary_emb." in k: continue
    v = v.to(Device.DEFAULT)
    if "model.layers" in k:
      if "q_proj" in k:
        v = permute(v, n_heads)
      elif "k_proj" in k:
        v = permute(v, n_kv_heads)
    sd[keymap[k]] = v
  return sd
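
# Note on permute: Hugging Face checkpoints store q_proj/k_proj for the "rotate-half" RoPE layout,
# while apply_rotary_emb above expects interleaved (real, imag) pairs; permute reorders the
# projection rows so both layouts yield the same attention output.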

def fix_bf16(weights: Dict[Any, Tensor]):
  if getenv("SUPPORT_BF16", 1):
    # TODO: without casting to float16, 70B llama OOM on tinybox.
    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
  # TODO: check if device supports bf16
  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
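
# Hypothetical loading sketch (assumed usage; `weights` is a dict of Tensors, e.g. from tinygrad's
# nn.state.safe_load):
#   weights = fix_bf16(convert_from_huggingface(weights, model, n_heads=n_heads, n_kv_heads=n_kv_heads))
#   nn.state.load_state_dict(model, weights, strict=False)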