mamba.py

import os, sys, math, argparse, time
sys.path.append(os.getcwd())
from typing import Any, Optional, Dict
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import fetch
from tinygrad.nn.state import load_state_dict, torch_load
from tqdm import tqdm
from transformers import AutoTokenizer

MODELS = {
  "130m": {"dim": 768, "n_layers": 24, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
  "370m": {"dim": 1024, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
  "790m": {"dim": 1536, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
  "1.4b": {"dim": 2048, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
  "2.8b": {"dim": 2560, "n_layers": 64, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
}

def fetch_weights(model_name: str) -> Dict[str, Tensor]:
  if model_name not in MODELS:
    raise ValueError(f"Requested unknown mamba model: {model_name}")
  downloaded = fetch(f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true")
  return torch_load(downloaded)

def selective_scan_ref(
  u,
  delta,
  A,
  B,
  C,
  D=None,
  z=None,
  delta_bias=None,
  delta_softplus=False,
  return_last_state=False,
):
  """
  Shapes: r(...) denotes a real tensor, c(...) a complex tensor.
  u: r(B D L)
  delta: r(B D L)
  A: c(D N) or r(D N)
  B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or c(B G N L)
  C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or c(B G N L)
  D: r(D)
  z: r(B D L)
  delta_bias: r(D), fp32
  out: r(B D L)
  last_state (optional): r(B D dstate) or c(B D dstate)
  """
  u = u.float()
  delta = delta.float()
  if delta_bias is not None:
    delta = delta + delta_bias[..., None].float()
  if delta_softplus:
    delta = delta.softplus()
  batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
  is_variable_B = len(B.shape) >= 3
  is_variable_C = len(C.shape) >= 3
  x = Tensor.zeros(batch, dim, dstate)
  ys = []
  deltaA = Tensor.einsum("bdl,dn->bdln", delta, A).exp()
  if not is_variable_B:
    deltaB_u = Tensor.einsum("bdl,dn,bdl->bdln", delta, B, u)
  else:
    if len(B.shape) == 3:
      deltaB_u = Tensor.einsum("bdl,bnl,bdl->bdln", delta, B, u)
    else:
      B = B.repeat((1, dim // B.shape[1], 1, 1))
      deltaB_u = Tensor.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
  if is_variable_C and len(C.shape) == 4:
    C = C.repeat((1, dim // C.shape[1], 1, 1))
  last_state = None
  for i in range(u.shape[2]):
    x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
    if not is_variable_C:
      y = Tensor.einsum("bdn,dn->bd", x, C)
    else:
      if len(C.shape) == 3:
        y = Tensor.einsum("bdn,bn->bd", x, C[:, :, i])
      else:
        y = Tensor.einsum("bdn,bdn->bd", x, C[:, :, :, i])
    if i == u.shape[2] - 1:
      last_state = x
    ys.append(y)
  y = Tensor.stack(*ys, dim=2)  # (batch dim L)
  out = y if D is None else y + u * D.reshape((-1, 1))
  if z is not None:
    out = out * z.silu()
  return out if not return_last_state else (out, last_state)
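
# A minimal shape sanity check for selective_scan_ref (an illustration, not part of the
# original script): with the time-varying B/C case used by MambaMixer below, inputs of
# shape (batch, dim, seqlen) should produce an output of the same shape.
#   u, delta = Tensor.randn(1, 4, 3), Tensor.randn(1, 4, 3)
#   A = -Tensor.rand(4, 2)
#   B, C = Tensor.randn(1, 2, 3), Tensor.randn(1, 2, 3)
#   out = selective_scan_ref(u, delta, A, B, C, D=Tensor.ones(4), delta_softplus=True)
#   assert out.shape == (1, 4, 3)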

class MambaMixer:
  def __init__(
    self,
    dim,
    d_state=16,
    d_conv=4,
    expand=2,
    dt_rank="auto",
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    dt_init_floor=1e-4,
    conv_bias=True,
    bias=False,
    layer_idx=None,
  ):
    self.dim = dim
    self.d_state = d_state
    self.d_conv = d_conv
    self.expand = expand
    self.d_inner = self.expand * self.dim
    self.dt_rank = math.ceil(self.dim / 16) if dt_rank == "auto" else dt_rank
    self.layer_idx = layer_idx
    self.in_proj = nn.Linear(self.dim, self.d_inner * 2, bias=bias)
    self.conv1d = nn.Conv1d(in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias,
                            kernel_size=d_conv, groups=self.d_inner, padding=d_conv-1)
    self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
    self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
    # Initialize special dt projection to preserve variance at initialization
    dt_init_std = self.dt_rank**-0.5 * dt_scale
    if dt_init == "constant":
      self.dt_proj.weight = Tensor.full(self.dt_proj.weight.shape, dt_init_std)
    elif dt_init == "random":
      self.dt_proj.weight = Tensor.uniform(self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std)
    else:
      raise NotImplementedError
    dt = Tensor.uniform(self.d_inner, low=math.log(dt_min), high=math.log(dt_max)).exp().maximum(dt_init_floor)
    inv_dt = dt + (1 - (-dt).exp()).log()
    self.dt_proj.bias.assign(inv_dt)
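    # inv_dt is the inverse of softplus applied to dt, so softplus(dt_proj.bias) recovers a
    # step size in [dt_min, dt_max] once delta_softplus is applied inside the scan.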
    # S4D real initialization
    self.A_log = Tensor.arange(1, self.d_state+1).repeat([self.d_inner, 1]).log()
    # D "skip" parameter
    self.D = Tensor.ones(self.d_inner)  # Keep in fp32
    self.out_proj = nn.Linear(self.d_inner, self.dim, bias=bias)
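
  # Shape summary for one mixer (derived from the layers above):
  #   in_proj : (B, L, dim)    -> (B, 2*d_inner, L), split into x and z
  #   conv1d  : depthwise causal conv over L on x, kernel size d_conv
  #   x_proj  : (B*L, d_inner) -> dt (dt_rank) + B (d_state) + C (d_state)
  #   dt_proj : dt_rank        -> d_inner (per-channel step size delta)
  #   out_proj: (B, L, d_inner)-> (B, L, dim)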
  def __call__(self, hidden_states: Tensor):
    batch, seqlen, _ = hidden_states.shape
    if not hasattr(self, 'conv_state'):
      self.conv_state = Tensor.zeros(batch, self.dim * self.expand, self.d_conv).contiguous().realize()
      self.ssm_state = Tensor.zeros(batch, self.dim * self.expand, self.d_state).realize()
      xz = self.in_proj.weight @ hidden_states.permute(2,0,1).reshape(hidden_states.shape[2], hidden_states.shape[1]*hidden_states.shape[0])
      xz = xz.reshape(xz.shape[0], xz.shape[1]//seqlen, seqlen).permute(1,0,2)
      if self.in_proj.bias is not None:
        xz = xz + self.in_proj.bias.reshape((-1, 1))
      A = -self.A_log.exp()
      x, z = xz.chunk(2, dim=1)
      # Compute short convolution
      self.conv_state.assign(x[:, :, -self.d_conv:])  # Update state (B D W)
      x = self.conv1d(x)[..., :seqlen].swish()
      x_dbl = self.x_proj(x.permute(0,2,1).reshape(x.shape[0]*x.shape[2], x.shape[1]))
      dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
      dt = self.dt_proj.weight @ dt.T
      dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
      B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1)
      C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1)
      # TODO: actually implement selective_scan_fn
      y = selective_scan_ref(x, dt, A, B, C, self.D, z=z, delta_bias=self.dt_proj.bias, delta_softplus=True,
                             return_last_state=True)
      y, last_state = y
      self.ssm_state.assign(last_state).realize()
      y = y.permute(0,2,1)
      out = self.out_proj(y)
      return out
    else:
      return self.step(hidden_states)
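
  # step() is the single-token decode path: it reuses conv_state as a rolling window for the
  # depthwise convolution and ssm_state as the recurrent SSM state, so each new token costs
  # O(1) in sequence length instead of re-running the full scan over the prompt.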
  def step(self, hidden_states: Tensor):
    assert hidden_states.shape[1] == 1, f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
    xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
    x, z = xz.chunk(2, dim=-1)  # (B D)
    # Conv step
    self.conv_state.assign(self.conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1).realize())
    x = (self.conv_state * self.conv1d.weight.squeeze(1)).sum(-1)
    if self.conv1d.bias is not None:
      x = x + self.conv1d.bias
    x = x.swish()
    x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
    dt = x_db[:, :self.dt_rank]
    B = x_db[:, self.dt_rank:(self.dt_rank + self.d_state)]
    C = x_db[:, (self.dt_rank + self.d_state):]
    # Don't add dt_bias here
    dt = self.dt_proj.weight @ dt.T
    A = -self.A_log.exp()
    # SSM step
    dt = (dt + self.dt_proj.bias.unsqueeze(-1)).softplus()
    dA = Tensor.einsum("db,dn->bdn", dt, A).exp()
    dB = Tensor.einsum("db,bn->bdn", dt, B)
    self.ssm_state.assign(self.ssm_state * dA + x.unsqueeze(-1) * dB)
    y = Tensor.einsum("bdn,bn->bd", self.ssm_state, C)
    y = y + self.D * x
    y = y * z.swish()  # (B D)
    out = self.out_proj(y)
    return out.unsqueeze(1)

class MambaBlock:
  def __init__(self, dim: int, norm_eps: float = 1e-5, rms_norm: bool = True, layer_idx: Optional[int] = None):
    self.mixer = MambaMixer(dim, layer_idx=layer_idx)
    if rms_norm:
      self.norm = nn.RMSNorm(dim, norm_eps)
    else:
      raise NotImplementedError
  def __call__(self, hidden_states: Tensor, residual: Optional[Tensor] = None):
    residual = (hidden_states + residual) if residual is not None else hidden_states
    hidden_states = self.norm(residual)
    hidden_states = self.mixer(hidden_states)
    return hidden_states, residual

class MambaBackbone:
  def __init__(self, dim: int, n_layers: int, vocab_size: int, rms_norm: bool = True, norm_eps: float = 1e-5):
    self.embedding = nn.Embedding(vocab_size, dim)
    self.layers = [MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)]
    if rms_norm:
      self.norm_f = nn.RMSNorm(dim, norm_eps)
  def __call__(self, input_ids: Tensor) -> Any:
    hidden_states = self.embedding(input_ids)
    residual = None
    for layer in self.layers:
      hidden_states, residual = layer(hidden_states, residual)
    residual = (hidden_states + residual) if residual is not None else hidden_states
    hidden_states = self.norm_f(residual)
    return hidden_states

class Mamba:
  def __init__(self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1):
    if vocab_size % pad_vocab_size_multiple != 0:
      vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
    self.backbone = MambaBackbone(dim, n_layers, vocab_size)
    self.lm_head = nn.Linear(dim, vocab_size, bias=False)
    self.forward_jit = TinyJit(self.forward)
  def forward(self, input_ids: Tensor):
    hidden_states = self.backbone(input_ids)
    return self.lm_head(hidden_states).realize()
  def __call__(self, input_ids):
    return self.forward(input_ids)
  @staticmethod
  def from_pretrained(model_name: str):
    weights = fetch_weights(model_name)
    model = Mamba(**MODELS[model_name])
    load_state_dict(model, weights)
    return model
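
# Minimal usage sketch (mirrors the __main__ block below; the "130m" size is illustrative):
#   tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
#   model = Mamba.from_pretrained("130m")
#   print(generate(model, tokenizer, "Why is gravity ", n_tokens_to_gen=10))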

def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: float = 1.0, sample: bool = False, top_k: Optional[int] = None):
  tks = tokenizer(prompt)["input_ids"]
  while len(tks) < 4:
    tks = [50279] + tks
  # Loading in the prompt tokens
  logits = model.forward(Tensor([tks]))[:, -1, :]
  for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
    # TODO: topk
    if sample:
      tok_Tens = (logits/temp).softmax().multinomial()
    else:
      tok_Tens = logits.argmax(axis=-1).unsqueeze(0)
    tok = tok_Tens.item()
    tks.append(tok)
    logits = model.forward_jit(tok_Tens)[:, -1, :]
  output_completions = ''.join([tokenizer.decode(output) for output in tks])
  return output_completions
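
# Note on the loop above: the first model.forward call runs the full selective scan over the
# prompt and fills each layer's conv_state/ssm_state; every later token is a single (1, 1)
# input that takes the cached-state step() path inside the TinyJit'd forward.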

if __name__ == "__main__":
  ORIG_PROMPT = "Why is gravity "
  parser = argparse.ArgumentParser(description="Run Mamba in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument("--prompt", type=str, default="Why is gravity ", help="Prompt for LLM completion")
  parser.add_argument("--size", type=str, default="370m",
                      help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]")
  parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
  parser.add_argument("--sample", dest="sample", action="store_true", help="Sample flag")
  parser.add_argument("--temp", type=float, default=1.0, help="Sampling temp has to be <=1.0")
  args = parser.parse_args()

  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
  model = Mamba.from_pretrained(args.size)

  prompt = args.prompt
  num_toks = args.n_tokens
  sample = args.sample
  temp = args.temp

  s = time.time()
  tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp)
  print(tinyoutput)
  print('TIME: ', time.time() - s)
  TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"
  if ORIG_PROMPT == prompt and not sample and num_toks == 10 and args.size == '370m': print('Outputs Match:', tinyoutput == TORCHOUTPUT)