
tinygrad inference engine

Alex Cheema, 1 year ago
commit 490fa102a4
3 changed files with 423 additions and 7 deletions
  1. inference/test_inference_engine.py (+20 -7)
  2. inference/tinygrad/inference.py (+191 -0)
  3. inference/tinygrad/models/llama.py (+212 -0)

+ 20 - 7
inference/test_inference_engine.py

@@ -1,21 +1,34 @@
 from inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from inference.inference_engine import InferenceEngine
 from inference.shard import Shard
+from inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 import numpy as np
 
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine: InferenceEngine, model_id: str, input_data: np.array):
-    resp_full, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=1, n_layers=2), input_data=input_data)
+    # inference_engine.reset_shard(Shard("", 0,0,0))
+    resp_full, _ = await inference_engine.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=1, n_layers=2), prompt="In one word, what is the capital of USA? ")
 
-    resp1, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=0, n_layers=2), input_data=input_data)
-    resp2, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=1, end_layer=1, n_layers=2), input_data=resp1)
+    print("resp_full", resp_full)
+    print("decoded", inference_engine.tokenizer.decode(resp_full))
 
-    assert np.array_equal(resp_full, resp2)
+    # inference_engine.reset_shard(Shard("", 0,0,0))
+
+    # resp1, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=0, n_layers=2), input_data=input_data)
+    # resp2, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=1, end_layer=1, n_layers=2), input_data=resp1)
+
+    # assert np.array_equal(resp_full, resp2)
 
 import asyncio
 
+# asyncio.run(test_inference_engine(
+#     MLXDynamicShardInferenceEngine(),
+#     "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+#     [1234]
+# ))
+
 asyncio.run(test_inference_engine(
-    MLXDynamicShardInferenceEngine(),
-    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+    TinygradDynamicShardInferenceEngine(),
+    "/Users/alex/Library/Caches/tinygrad/downloads/llama3-8b-sfr",
     [1234]
-))
+))
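
The test above drives each engine through the Shard interface. For reference, here is a minimal sketch of what inference/shard.py is assumed to look like, with field names and helpers inferred from the call sites in this diff; the real definition may differ.

# Hypothetical sketch of inference/shard.py, inferred from how the test constructs Shard.
from dataclasses import dataclass

@dataclass(frozen=True)
class Shard:
  model_id: str     # model directory or repo id
  start_layer: int  # first transformer layer owned by this shard (inclusive)
  end_layer: int    # last transformer layer owned by this shard (inclusive)
  n_layers: int     # total number of layers in the full model

  def is_first_layer(self) -> bool:
    return self.start_layer == 0

  def is_last_layer(self) -> bool:
    return self.end_layer == self.n_layers - 1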

+ 191 - 0
inference/tinygrad/inference.py

@@ -0,0 +1,191 @@
+
+from pathlib import Path
+from typing import List
+import json, argparse, random, time
+import tiktoken
+from tiktoken.load import load_tiktoken_bpe
+from inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
+from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
+from tinygrad.helpers import Profiling, Timing, DEBUG, colored, fetch, tqdm
+from inference.shard import Shard
+from inference.inference_engine import InferenceEngine
+import numpy as np
+
+MODEL_PARAMS = {
+  "8B": {
+    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
+    "files": 1
+  },
+  "70B": {
+    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 28672},
+    "files": 8
+  }
+}
+
+class Tokenizer:
+  pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
+  def __init__(self, model_path: str):
+    mergeable_ranks = load_tiktoken_bpe(model_path)
+    self.num_base_tokens = len(mergeable_ranks)
+    special_tokens = [
+      "<|begin_of_text|>",
+      "<|end_of_text|>",
+      "<|reserved_special_token_0|>",
+      "<|reserved_special_token_1|>",
+      "<|reserved_special_token_2|>",
+      "<|reserved_special_token_3|>",
+      "<|start_header_id|>",
+      "<|end_header_id|>",
+      "<|reserved_special_token_4|>",
+      "<|eot_id|>",
+    ] + [
+      f"<|reserved_special_token_{i}|>"
+      for i in range(5, 256 - 5)
+    ]
+    self.special_tokens = {token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}
+
+    self.model = tiktoken.Encoding(name=model_path, pat_str=self.pat_str, mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens)
+
+  @property
+  def bos_id(self): return self.special_tokens["<|begin_of_text|>"]
+  @property
+  def stop_tokens(self): return {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]}
+
+  def decode(self, toks): return self.model.decode([t for t in toks if t < self.num_base_tokens])
+  def encode(self, text, allow_special=False):
+    return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())
+
+# **** helper functions ****
+def concat_weights(models, device=None):
+  def convert(name) -> Tensor:
+    disk_tensors: List[Tensor] = [model[name] for model in models]
+    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
+      return disk_tensors[0].to(device=device)
+    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
+    lazy_tensors = [data.to(device=device) for data in disk_tensors]
+    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
+  return {name: convert(name) for name in {name: None for model in models for name in model}}
+
+def load(fn:str):
+  if fn.endswith('.index.json'):
+    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
+    parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
+    return {k: parts[n][k] for k, n in weight_map.items()}
+  elif fn.endswith(".safetensors"):
+    return safe_load(fn)
+  else:
+    return torch_load(fn)
+
+def build_transformer(model_path: Path, model_size="8B", quantize=None, device=None):
+  # build model
+  linear = nn.Linear
+  with Context(THREEFRY=0):
+    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True)
+
+  # load weights
+  if model_path.is_dir():
+    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"))
+    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"))
+    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
+  else:
+    weights = load(str(model_path))
+  if "model.embed_tokens.weight" in weights:
+    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
+  weights = fix_bf16(weights)
+
+  with Context(BEAM=0):
+    # quantize
+    if quantize is not None:
+      weights = linear.quantize(weights, device)
+      for _,v in weights.items(): v.realize()
+
+    # shard
+    if isinstance(device, tuple):
+      for k,v in nn.state.get_state_dict(model).items():
+        if 'scale' in k: v.shard_(device, axis=None)  # from quantized
+        elif '.attention.' in k: v.shard_(device, axis=-1)
+        elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
+        elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
+        elif '.feed_forward.' in k: v.shard_(device, axis=-1)
+        elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
+        elif 'output.weight' in k: v.shard_(device, axis=0)
+        else: v.shard_(device, axis=None)
+
+    # replace weights in model
+    load_state_dict(model, weights, strict=False, consume=True)
+  return model
+
+# default settings
+TEMPERATURE = 0.85
+TOP_K = 25
+TOP_P = 0.9
+ALPHA_F = 0.1
+ALPHA_P = 0.0
+
+last_seen_toks = []
+def prefill(model, toks, start_pos=0):
+  global last_seen_toks
+
+  # we can skip part of the prompt if it is the same as last and start_pos=0
+  if start_pos == 0:
+    for i, (a, b) in enumerate(zip(toks, last_seen_toks)):
+      if a != b: break
+    else: i = min(len(toks), len(last_seen_toks))
+    start_pos += i
+    last_seen_toks = toks
+    toks = toks[i:]
+
+  # prefill the model
+  for tok in tqdm(toks):
+    GlobalCounters.reset()
+    model(Tensor([[tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).realize()
+    start_pos += 1
+  return start_pos
+
+class TinygradDynamicShardInferenceEngine(InferenceEngine):
+    def __init__(self):
+        self.shard = None
+
+    async def infer_prompt(self, shard: Shard, prompt: str) -> (np.ndarray, bool):
+        def encode_role(role: str):
+            return [self.tokenizer.special_tokens["<|start_header_id|>"]] + self.tokenizer.encode(role) + [self.tokenizer.special_tokens["<|end_header_id|>"]] + self.tokenizer.encode("\n\n")
+        def encode_message(role: str, content: str):
+            return encode_role(role) + self.tokenizer.encode(content.strip()) + [self.tokenizer.special_tokens["<|eot_id|>"]]
+
+        await self.ensure_shard(shard)
+        print([self.tokenizer.encode(prompt)])
+
+        toks = [self.tokenizer.bos_id] + encode_message("user", prompt) + encode_role("assistant")
+        start_pos = prefill(self.model, toks[:-1])
+        last_tok = toks[-1]
+        
+        output_data = np.array(self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist())
+        start_pos += 1
+
+        return output_data, output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens
+
+    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, bool):
+        await self.ensure_shard(shard)
+        output_data: np.ndarray = np.array(self.model(Tensor([input_data]), 0, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist())
+        return output_data, output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens
+
+    async def reset_shard(self, shard: Shard):
+        await self.ensure_shard(shard)
+
+        print(f"Resetting shard: {shard}")
+        self.model.reset()
+
+    async def ensure_shard(self, shard: Shard):
+        if self.shard == shard:
+            return
+
+        model_path = Path(shard.model_id)
+        size = "8B" # one of 8B or 70B for now
+        model = build_transformer(model_path, model_size=size)
+        tokenizer = Tokenizer(str((model_path if model_path.is_dir() else model_path.parent) / "tokenizer.model"))
+
+        self.shard = shard
+        self.model = model
+        self.tokenizer = tokenizer
+

+ 212 - 0
inference/tinygrad/models/llama.py

@@ -0,0 +1,212 @@
+from typing import Tuple, Union, Optional, Dict, Any
+from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
+from tinygrad.helpers import getenv
+
+# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
+  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
+  freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
+  # TODO: move dtype outside this
+  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim//2, 2)
+
+# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
+def complex_mult(A, c, d):
+  a,b = A[..., 0:1], A[..., 1:2]
+  ro = a*c - b*d
+  co = a*d + b*c
+  return ro.cat(co, dim=-1)
+
+def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
+  assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
+  xq = xq.reshape(*xq.shape[0:-1], -1, 2)
+  xk = xk.reshape(*xk.shape[0:-1], -1, 2)
+  assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
+  c, d = freqs_cis[..., 0:1], freqs_cis[..., 1:2]
+  xq_out = complex_mult(xq, c, d)
+  xk_out = complex_mult(xk, c, d)
+  return xq_out.flatten(3), xk_out.flatten(3)
+
+def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
+  bs, seqlen, n_kv_heads, head_dim = x.shape
+  if n_rep == 1: return x
+  # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
+  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
+
+class Attention:
+  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
+    self.n_heads = n_heads
+    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
+    self.head_dim = dim // n_heads
+    self.n_rep = self.n_heads // self.n_kv_heads
+    self.max_context = max_context
+
+    self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
+    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
+    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
+    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
+
+  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
+    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
+    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
+    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
+
+    xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
+    bsz, seqlen, _, _ = xq.shape
+
+    # create kv cache
+    if not hasattr(self, "cache_kv"):
+      self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
+      if isinstance(x.device, tuple):
+        # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
+        self.cache_kv.shard_((x.device), axis=None).realize()
+
+    # update the cache
+    assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
+    self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
+
+    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
+
+    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
+    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
+    attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
+    attn = attn.reshape(bsz, seqlen, -1)
+    return self.wo(attn)
+
+class FeedForward:
+  def __init__(self, dim:int, hidden_dim:int, linear=nn.Linear):
+    self.w1 = linear(dim, hidden_dim, bias=False)
+    self.w2 = linear(hidden_dim, dim, bias=False)
+    self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
+
+  def __call__(self, x:Tensor) -> Tensor:
+    return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
+
+class TransformerBlock:
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
+    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
+    self.feed_forward = feed_forward(dim, hidden_dim, linear)
+    self.attention_norm = nn.RMSNorm(dim, norm_eps)
+    self.ffn_norm = nn.RMSNorm(dim, norm_eps)
+
+  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
+    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
+    return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
+
+# standard openai sampling
+def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
+  assert logits.ndim == 1, "only works on 1d tensors"
+  assert 0 <= p <= 1, "p must be between 0 and 1"
+  assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
+
+  # if temperature is very low just use argmax
+  if temp < 1e-6: return logits.argmax()
+
+  # alpha sampling
+  if af or ap:
+    if not hasattr(sample, "alpha_counter"):
+      setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
+    logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0) * ap)
+
+  # replace NaNs with -inf
+  logits = (logits != logits).where(-float("inf"), logits)
+
+  # softmax
+  t = (logits / temp).softmax()
+
+  counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
+  # top k
+  if k:
+    output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
+    for i in range(k):
+      t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
+      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
+      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
+      t = (counter == t_argmax).where(0, t)
+
+    # approximate top p
+    # because we are already limited to top k elements we can do top p "without sorting"
+    output_cumsum = output[::-1]._cumsum()[::-1] + t.sum()
+    output = (output_cumsum >= (1 - p)) * output
+    output_indices = (output_cumsum >= (1 - p)) * output_indices
+
+    # sample
+    output_idx = output.multinomial()
+    output_token = output_indices[output_idx]
+  else:
+    output_token = t.multinomial()
+
+  # increase alpha counter
+  if af or ap:
+    sample.alpha_counter = (counter == output_token).where(sample.alpha_counter + 1, sample.alpha_counter)
+
+  return output_token
+
+class Transformer:
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
+    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
+    self.norm = nn.RMSNorm(dim, norm_eps)
+    self.tok_embeddings = nn.Embedding(vocab_size, dim)
+    self.output = nn.Linear(dim, vocab_size, bias=False)
+    self.max_context = max_context
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
+    self.forward_jit = TinyJit(self.forward) if jit else None
+
+  def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
+    _bsz, seqlen = tokens.shape
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
+
+    h = self.tok_embeddings(tokens)
+    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None
+    for i, layer in enumerate(self.layers):
+      h = layer(h, start_pos, freqs_cis, mask)
+      print(f"layer {i}", h.tolist().__str__()[0:100])
+    logits = self.output(self.norm(h)).float()[:, -1, :]
+
+    return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+
+  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
+    # TODO: better way to handle the first call v.s. the rest?
+    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
+      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
+    return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
+
+  def reset(self):
+    for layer in self.layers:
+      print(f"reset layer: {layer.attention.cache_kv}")
+      layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
+
+# *** helpers ***
+
+def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
+  def permute(v: Tensor, n_heads: int):
+    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
+
+  keymap = {
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
+    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
+    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
+    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
+    "model.norm.weight": "norm.weight",
+    "lm_head.weight": "output.weight",
+  }
+  sd = {}
+  for k, v in weights.items():
+    if ".rotary_emb." in k: continue
+    v = v.to(Device.DEFAULT)
+    if "model.layers" in k:
+      if "q_proj" in k:
+        v = permute(v, n_heads)
+      elif "k_proj" in k:
+        v = permute(v, n_kv_heads)
+    sd[keymap[k]] = v
+  return sd
+
+def fix_bf16(weights:Dict[Any, Tensor]):
+  if getenv("SUPPORT_BF16", 1):
+    # TODO: without casting to float16, 70B llama OOM on tinybox.
+    return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+  # TODO: check if device supports bf16
+  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
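
As a quick, hedged sanity check of the sample() helper above: the expected outputs in the comments are assumptions based on the code path, where a temperature below 1e-6 reduces to argmax and top-k restricts the candidate set.

# Tiny check of sample(): index 1 holds the largest logit.
from tinygrad import Tensor
from inference.tinygrad.models.llama import sample

logits = Tensor([0.1, 2.5, 0.3, 1.0])
print(sample(logits, temp=0.0, k=0, p=0.8, af=0.0, ap=0.0).item())   # 1 (argmax path)
print(sample(logits, temp=0.85, k=2, p=0.9, af=0.0, ap=0.0).item())  # 1 or 3 (top-2 candidates)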