import os, sys, math, argparse, time
sys.path.append(os.getcwd())
from typing import Any, Optional, Dict
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import fetch
from tinygrad.nn.state import load_state_dict, torch_load
from tqdm import tqdm
from transformers import AutoTokenizer
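# Configurations of the pretrained state-spaces/mamba-* checkpoints on Hugging Face;
# dim is the model width, and vocab_size gets padded up to a multiple of
# pad_vocab_size_multiple when the Mamba model is constructed below.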
MODELS = {
  "130m": {"dim": 768, "n_layers": 24, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
  "370m": {"dim": 1024, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
  "790m": {"dim": 1536, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
  "1.4b": {"dim": 2048, "n_layers": 48, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
  "2.8b": {"dim": 2560, "n_layers": 64, "vocab_size": 50277, "pad_vocab_size_multiple": 8},
}

def fetch_weights(model_name: str) -> Dict[str, Tensor]:
  if model_name not in MODELS:
    raise ValueError(f"Requested unknown mamba model: {model_name}")
  downloaded = fetch(f"https://huggingface.co/state-spaces/mamba-{model_name}/resolve/main/pytorch_model.bin?download=true")
  return torch_load(downloaded)
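# selective_scan_ref evaluates the discretized state-space recurrence
#   x_t = exp(delta_t * A) * x_{t-1} + (delta_t * B_t) * u_t
#   y_t = C_t . x_t + D * u_t
# sequentially over the length dimension L. delta, B, and C are input-dependent
# ("selective"), which is what distinguishes Mamba's SSM from a fixed S4 kernel.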
def selective_scan_ref(
  u,
  delta,
  A,
  B,
  C,
  D=None,
  z=None,
  delta_bias=None,
  delta_softplus=False,
  return_last_state=False,
):
  """
  Shape notation: r(...) = real tensor, c(...) = complex tensor.
  u: r(B D L)
  delta: r(B D L)
  A: c(D N) or r(D N)
  B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or r(B G N 2L)
  C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or r(B G N 2L)
  D: r(D)
  z: r(B D L)
  delta_bias: r(D), fp32
  out: r(B D L)
  last_state (optional): r(B D dstate) or c(B D dstate)
  """
  u = u.float()
  delta = delta.float()
  if delta_bias is not None:
    delta = delta + delta_bias[..., None].float()
  if delta_softplus:
    delta = delta.softplus()
  batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
  is_variable_B = len(B.shape) >= 3
  is_variable_C = len(C.shape) >= 3
  x = Tensor.zeros(batch, dim, dstate)
  ys = []
  deltaA = Tensor.einsum("bdl,dn->bdln", delta, A).exp()
  if not is_variable_B:
    deltaB_u = Tensor.einsum("bdl,dn,bdl->bdln", delta, B, u)
  else:
    if len(B.shape) == 3:
      deltaB_u = Tensor.einsum("bdl,bnl,bdl->bdln", delta, B, u)
    else:
      B = B.repeat((1, dim // B.shape[1], 1, 1))
      deltaB_u = Tensor.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
  if is_variable_C and len(C.shape) == 4:
    C = C.repeat((1, dim // C.shape[1], 1, 1))
  last_state = None
  for i in range(u.shape[2]):
    x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
    if not is_variable_C:
      y = Tensor.einsum("bdn,dn->bd", x, C)
    else:
      if len(C.shape) == 3:
        y = Tensor.einsum("bdn,bn->bd", x, C[:, :, i])
      else:
        y = Tensor.einsum("bdn,bdn->bd", x, C[:, :, :, i])
    if i == u.shape[2] - 1:
      last_state = x
    ys.append(y)
  y = Tensor.stack(*ys, dim=2)  # (batch dim L)
  out = y if D is None else y + u * D.reshape((-1, 1))
  if z is not None:
    out = out * z.silu()
  return out if not return_last_state else (out, last_state)
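# A minimal sketch of exercising selective_scan_ref on random inputs (shapes per the
# docstring above; kept in a comment so importing this file has no side effects):
#   u, delta = Tensor.randn(1, 8, 16), Tensor.randn(1, 8, 16)
#   A = -Tensor.rand(8, 4)  # negative real part keeps exp(delta * A) stable
#   B, C = Tensor.randn(8, 4), Tensor.randn(8, 4)
#   out = selective_scan_ref(u, delta, A, B, C, delta_softplus=True)  # r(1 8 16)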
class MambaMixer:
  def __init__(
    self,
    dim,
    d_state=16,
    d_conv=4,
    expand=2,
    dt_rank="auto",
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    dt_init_floor=1e-4,
    conv_bias=True,
    bias=False,
    layer_idx=None,
  ):
    self.dim = dim
    self.d_state = d_state
    self.d_conv = d_conv
    self.expand = expand
    self.d_inner = self.expand * self.dim
    self.dt_rank = math.ceil(self.dim / 16) if dt_rank == "auto" else dt_rank
    self.layer_idx = layer_idx
    self.in_proj = nn.Linear(self.dim, self.d_inner * 2, bias=bias)
    # Depthwise (groups == channels) causal convolution; padding keeps the output at least seqlen long
    self.conv1d = nn.Conv1d(in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias,
                            kernel_size=d_conv, groups=self.d_inner, padding=d_conv-1)
    self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
    self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
    # Initialize special dt projection to preserve variance at initialization
    dt_init_std = self.dt_rank**-0.5 * dt_scale
    if dt_init == "constant":
      self.dt_proj.weight = Tensor.full(self.dt_proj.weight.shape, dt_init_std)
    elif dt_init == "random":
      self.dt_proj.weight = Tensor.uniform(self.dt_proj.weight.shape, low=-dt_init_std, high=dt_init_std)
    else:
      raise NotImplementedError
    # Sample dt log-uniformly in [dt_min, dt_max] and store its inverse softplus in the
    # bias, so that softplus(dt_proj.bias) recovers dt at initialization
    dt = Tensor.uniform(self.d_inner, low=math.log(dt_min), high=math.log(dt_max)).exp().maximum(dt_init_floor)
    inv_dt = dt + (1 - (-dt).exp()).log()
    self.dt_proj.bias.assign(inv_dt)
    # S4D real initialization
    self.A_log = Tensor.arange(1, self.d_state+1).repeat([self.d_inner, 1]).log()
    # D "skip" parameter
    self.D = Tensor.ones(self.d_inner)  # Keep in fp32
    self.out_proj = nn.Linear(self.d_inner, self.dim, bias=bias)
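  # The first call processes the whole prompt ("prefill") and materializes the
  # convolution and SSM caches; every later call sees the caches and takes the
  # single-token decode path in step() below.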
  def __call__(self, hidden_states: Tensor):
    batch, seqlen, _ = hidden_states.shape
    if not hasattr(self, 'conv_state'):
      self.conv_state = Tensor.zeros(batch, self.dim * self.expand, self.d_conv).contiguous().realize()
      self.ssm_state = Tensor.zeros(batch, self.dim * self.expand, self.d_state).realize()
      xz = self.in_proj.weight @ hidden_states.permute(2,0,1).reshape(hidden_states.shape[2], hidden_states.shape[1]*hidden_states.shape[0])
      xz = xz.reshape(xz.shape[0], xz.shape[1]//seqlen, seqlen).permute(1,0,2)
      if self.in_proj.bias is not None:
        xz = xz + self.in_proj.bias.reshape((-1, 1))
      A = -self.A_log.exp()
      x, z = xz.chunk(2, dim=1)
      # Compute short convolution; caching the last window requires seqlen >= d_conv
      # (generate() left-pads short prompts to guarantee this)
      self.conv_state.assign(x[:, :, -self.d_conv:])  # Update state (B D W)
      x = self.conv1d(x)[..., :seqlen].swish()
      x_dbl = self.x_proj(x.permute(0,2,1).reshape(x.shape[0]*x.shape[2], x.shape[1]))
      dt, B, C = Tensor.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
      dt = self.dt_proj.weight @ dt.T
      dt = dt.reshape(dt.shape[0], dt.shape[1]//seqlen, seqlen).permute(1,0,2)
      B = B.reshape(B.shape[0]//seqlen, seqlen, B.shape[1]).permute(0,2,1)
      C = C.reshape(C.shape[0]//seqlen, seqlen, C.shape[1]).permute(0,2,1)
      # TODO: actually implement selective_scan_fn
      y = selective_scan_ref(x, dt, A, B, C, self.D, z=z, delta_bias=self.dt_proj.bias, delta_softplus=True,
                             return_last_state=True)
      y, last_state = y
      self.ssm_state.assign(last_state).realize()
      y = y.permute(0,2,1)
      out = self.out_proj(y)
      return out
    else:
      return self.step(hidden_states)
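  # Single-token decode: the same recurrence selective_scan_ref evaluates, but one
  # timestep at a time against the cached states:
  #   conv_state shifts left by one and appends the new token's features
  #   ssm_state <- ssm_state * exp(dt*A) + x * (dt*B),  y = C . ssm_state + D*x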
  def step(self, hidden_states: Tensor):
    assert hidden_states.shape[1] == 1, f"Only support decoding with 1 token at a time for now, attempted {hidden_states.shape[1]}"
    xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
    x, z = xz.chunk(2, dim=-1)  # (B D)
    # Conv step: shift the cached window left by one and append the new token
    self.conv_state.assign(self.conv_state[:, :, 1:].cat(x.unsqueeze(-1), dim=-1).realize())
    x = (self.conv_state * self.conv1d.weight.squeeze(1)).sum(-1)
    if self.conv1d.bias is not None:
      x = x + self.conv1d.bias
    x = x.swish()
    x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
    dt = x_db[:, :self.dt_rank]
    B = x_db[:, self.dt_rank:(self.dt_rank + self.d_state)]
    C = x_db[:, (self.dt_rank + self.d_state):]
    # dt_bias is added after the matmul, inside the softplus below
    dt = self.dt_proj.weight @ dt.T
    A = -self.A_log.exp()
    # SSM step
    dt = (dt + self.dt_proj.bias.unsqueeze(-1)).softplus()
    dA = Tensor.einsum("db,dn->bdn", dt, A).exp()
    dB = Tensor.einsum("db,bn->bdn", dt, B)
    self.ssm_state.assign(self.ssm_state * dA + x.unsqueeze(-1) * dB)
    y = Tensor.einsum("bdn,bn->bd", self.ssm_state, C)
    y = y + self.D * x
    y = y * z.swish()  # (B D)
    out = self.out_proj(y)
    return out.unsqueeze(1)
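# Each block is pre-norm: normalize the incoming residual stream, run the mixer,
# and return the unnormalized residual so the next block can add onto it.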
class MambaBlock:
  def __init__(self, dim: int, norm_eps: float = 1e-5, rms_norm: bool = True, layer_idx: Optional[int] = None):
    self.mixer = MambaMixer(dim, layer_idx=layer_idx)
    if rms_norm:
      self.norm = nn.RMSNorm(dim, norm_eps)
    else:
      raise NotImplementedError
  def __call__(self, hidden_states: Tensor, residual: Optional[Tensor] = None):
    residual = (hidden_states + residual) if residual is not None else hidden_states
    hidden_states = self.norm(residual)
    hidden_states = self.mixer(hidden_states)
    return hidden_states, residual
class MambaBackbone:
  def __init__(self, dim: int, n_layers: int, vocab_size: int, rms_norm: bool = True, norm_eps: float = 1e-5):
    self.embedding = nn.Embedding(vocab_size, dim)
    self.layers = [MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)]
    if rms_norm:
      self.norm_f = nn.RMSNorm(dim, norm_eps)
  def __call__(self, input_ids: Tensor) -> Any:
    hidden_states = self.embedding(input_ids)
    residual = None
    for layer in self.layers:
      hidden_states, residual = layer(hidden_states, residual)
    residual = (hidden_states + residual) if residual is not None else hidden_states
    hidden_states = self.norm_f(residual)
    return hidden_states
class Mamba:
  def __init__(self, dim: int, n_layers: int, vocab_size: int, pad_vocab_size_multiple: int = 1):
    if vocab_size % pad_vocab_size_multiple != 0:
      vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
    self.backbone = MambaBackbone(dim, n_layers, vocab_size)
    self.lm_head = nn.Linear(dim, vocab_size, bias=False)
    self.forward_jit = TinyJit(self.forward)
  def forward(self, input_ids: Tensor):
    hidden_states = self.backbone(input_ids)
    return self.lm_head(hidden_states).realize()
  def __call__(self, input_ids):
    return self.forward(input_ids)
  @staticmethod
  def from_pretrained(model_name: str):
    weights = fetch_weights(model_name)
    model = Mamba(**MODELS[model_name])
    load_state_dict(model, weights)
    return model
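# A minimal usage sketch (the "130m" checkpoint is downloaded on first use):
#   model = Mamba.from_pretrained("130m")
#   logits = model(Tensor([[1, 2, 3, 4]]))  # (1, 4, padded vocab size)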
def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: float = 1.0, sample: bool = False, top_k: Optional[int] = None):
  tks = tokenizer(prompt)["input_ids"]
  # Left-pad short prompts so the prefill convolution sees at least d_conv (4) tokens
  while len(tks) < 4:
    tks = [50279] + tks
  # Run the prompt through the model once (prefill); this also fills the conv/ssm caches
  logits = model.forward(Tensor([tks]))[:, -1, :]
  for _ in tqdm(range(n_tokens_to_gen), desc="Generating tokens"):
    # TODO: topk
    if sample:
      tok_tensor = (logits / temp).softmax().multinomial()
    else:
      tok_tensor = logits.argmax(axis=-1).unsqueeze(0)
    tok = tok_tensor.item()
    tks.append(tok)
    logits = model.forward_jit(tok_tensor)[:, -1, :]
  output_completions = ''.join([tokenizer.decode(output) for output in tks])
  return output_completions
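# Example shell invocation (assuming this script is saved as mamba.py):
#   python3 mamba.py --prompt "Why is gravity " --size 370m --n_tokens 10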
if __name__ == "__main__":
  ORIG_PROMPT = "Why is gravity "
  parser = argparse.ArgumentParser(description="Run Mamba in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument("--prompt", type=str, default=ORIG_PROMPT, help="Prompt for LLM completion")
  parser.add_argument("--size", type=str, default="370m",
                      help=f"Size of model to use [{', '.join(MODELS.keys())}]")
  parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
  parser.add_argument("--sample", dest="sample", action="store_true", help="Sample instead of greedy decoding")
  parser.add_argument("--temp", type=float, default=1.0, help="Sampling temperature (must be <= 1.0)")
  args = parser.parse_args()

  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
  model = Mamba.from_pretrained(args.size)

  s = time.time()
  tinyoutput = generate(model, tokenizer, args.prompt, n_tokens_to_gen=args.n_tokens, sample=args.sample, temp=args.temp)
  print(tinyoutput)
  print('TIME: ', time.time() - s)

  # Reference output from the PyTorch implementation for the default settings
  TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"
  if ORIG_PROMPT == args.prompt and not args.sample and args.n_tokens == 10 and args.size == '370m':
    print('Outputs Match:', tinyoutput == TORCHOUTPUT)
|