Browse Source

refactor tinygrad, only load necessary layers for each shard fixes #128, enable JIT (much faster), prefill all layers not just the first shard fixes #12, use new ShardDownloader for more robust, parallel downloads

Alex Cheema 10 months ago
parent
commit
2be446546f

+ 4 - 3
exo/inference/test_inference_engine.py

@@ -11,6 +11,7 @@ import numpy as np
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+
   next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
@@ -52,7 +53,7 @@ asyncio.run(
 
 # TODO: Need more memory or a smaller model
 # asyncio.run(test_inference_engine(
-#     TinygradDynamicShardInferenceEngine(),
-#     TinygradDynamicShardInferenceEngine(),
-#     "mlx-community/Meta-Llama-3-8B-Instruct",
+#     TinygradDynamicShardInferenceEngine(HFShardDownloader()),
+#     TinygradDynamicShardInferenceEngine(HFShardDownloader()),
+#     "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
 # ))

+ 48 - 150
exo/inference/tinygrad/inference.py

@@ -1,199 +1,97 @@
 from pathlib import Path
-from typing import List, Optional
+from typing import List
 import json
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
-from tinygrad.nn.state import safe_load, torch_load, load_state_dict
-from tinygrad import Tensor, nn, Context, GlobalCounters
-from tinygrad.helpers import tqdm
 from exo.inference.shard import Shard
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict
+from tinygrad import Tensor, dtypes, nn, Context
+from transformers import AutoTokenizer
 from exo.inference.inference_engine import InferenceEngine
+from typing import Optional, Tuple
 import numpy as np
+from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
 
+Tensor.no_grad = True
+# default settings
+TEMPERATURE = 0.85
+TOP_K = 25
+TOP_P = 0.9
+ALPHA_F = 0.1
+ALPHA_P = 0.0
 MODEL_PARAMS = {
   "8B": {
-    "args": {
-      "dim": 4096,
-      "n_heads": 32,
-      "n_kv_heads": 8,
-      "n_layers": 32,
-      "norm_eps": 1e-5,
-      "rope_theta": 500000,
-      "vocab_size": 128256,
-      "hidden_dim": 14336,
-    },
-    "files": 1,
+    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
+    "files": 1
   },
   "70B": {
-    "args": {
-      "dim": 8192,
-      "n_heads": 64,
-      "n_kv_heads": 8,
-      "n_layers": 80,
-      "norm_eps": 1e-5,
-      "rope_theta": 500000,
-      "vocab_size": 128256,
-      "hidden_dim": 28672,
-    },
-    "files": 8,
-  },
+    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 28672},
+    "files": 8
+  }
 }
 
-
-
-# **** helper functions ****
-
-def concat_weights(models, device=None):
-  def convert(name) -> Tensor:
-    disk_tensors: List[Tensor] = [model[name] for model in models]
-    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
-      return disk_tensors[0].to(device=device)
-    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
-    lazy_tensors = [data.to(device=device) for data in disk_tensors]
-    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
-
-  return {name: convert(name) for name in {name: None for model in models for name in model}}
-
-
-def load(fn: str):
-  if fn.endswith(".index.json"):
-    with open(fn) as fp:
-      weight_map = json.load(fp)["weight_map"]
-    parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
-    return {k: parts[n][k] for k, n in weight_map.items()}
-  elif fn.endswith(".safetensors"):
-    return safe_load(fn)
-  else:
-    return torch_load(fn)
-
-
-def build_transformer(model_path: Path, shard: Shard, model_size="8B", quantize=None, device=None):
+def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   # build model
   linear = nn.Linear
   with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], shard=shard, linear=linear, max_context=8192, jit=False)
+    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
 
   # load weights
   if model_path.is_dir():
-    if (model_path / "model.safetensors.index.json").exists():
-      weights = load(str(model_path / "model.safetensors.index.json"))
-    elif (model_path / "model.safetensors").exists():
-      weights = load(str(model_path / "model.safetensors"))
-    else:
-      weights = concat_weights(
-        [load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])],
-        device[0] if isinstance(device, tuple) else device,
-      )
+    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"), shard)
+    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"), shard)
+    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
   else:
-    weights = load(str(model_path))
-
+    weights = load(str(model_path), shard)
   if "model.embed_tokens.weight" in weights:
-    weights = convert_from_huggingface(
-      weights,
-      model,
-      MODEL_PARAMS[model_size]["args"]["n_heads"],
-      MODEL_PARAMS[model_size]["args"]["n_kv_heads"],
-      shard=shard,
-    )
+    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
   weights = fix_bf16(weights)
 
   with Context(BEAM=0):
-    # quantize
-    if quantize is not None:
-      weights = linear.quantize(weights, device)
-      for _, v in weights.items():
-        v.realize()
-
-    # shard
-    if isinstance(device, tuple):
-      for k, v in nn.state.get_state_dict(model).items():
-        if "scale" in k:
-          v.shard_(device, axis=None)  # from quantized
-        elif ".attention." in k:
-          v.shard_(device, axis=-1)
-        elif ".feed_forward.w1." in k:
-          v.shard_(device, axis=0)
-        elif ".feed_forward.w3." in k:
-          v.shard_(device, axis=0)
-        elif ".feed_forward." in k:
-          v.shard_(device, axis=-1)
-        elif "tok_embeddings.weight" in k:
-          v.shard_(device, axis=0)
-        elif "output.weight" in k:
-          v.shard_(device, axis=0)
-        else:
-          v.shard_(device, axis=None)
-
     # replace weights in model
-    load_state_dict(model, weights, strict=False, consume=True)
+    load_state_dict(model, weights, strict=False, consume=False) # consume=True
   return model
 
-
-# default settings
-TEMPERATURE = 0  # 0.85
-TOP_K = 25
-TOP_P = 0.9
-ALPHA_F = 0.1
-ALPHA_P = 0.0
-
-
-def prefill(model, toks, start_pos=0):
-  # prefill the model
-  for tok in tqdm(toks):
-    GlobalCounters.reset()
-    model(Tensor([[tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).realize()
-    start_pos += 1
-  return start_pos
-
-
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-    # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
     await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
+    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
+    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
     toks = self.tokenizer.encode(prompt)
-    start_pos = prefill(self.model, toks[:-1], start_pos=start_pos)
-    last_tok = toks[-1]
+    h = self.model(Tensor([toks]), start_pos, TEMPERATURE)
 
-    output_data = np.array([self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
-    if output_data.size == 1:
-      start_pos += 1
-
-    return (
-      output_data,
-      json.dumps({"start_pos": start_pos}),
-      output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id],
-    )
+    if h.shape == (1,):
+      start_pos += len(toks)
+      n_captured_toks = 1
+      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
+    else:
+      n_captured_toks += len(toks)
+      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks + 1}), False
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
+    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
+    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
-    output_data: np.ndarray = np.array([self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
-    if output_data.size == 1:
-      start_pos += 1
+    h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
 
-    return (
-      output_data,
-      json.dumps({"start_pos": start_pos}),
-      output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id],
-    )
+    if h.shape == (1,):
+      start_pos += n_captured_toks
+      n_captured_toks = 1
+      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
+    else:
+      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return
 
     model_path = await self.shard_downloader.ensure_shard(shard)
-    print(f"{model_path=}")
-    model = build_transformer(model_path, shard=shard, model_size="8B" if "8b" in shard.model_id else "70B" if "70b" in shard.model_id else "8B")
-    from transformers import AutoTokenizer
-    tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
-
+    self.model = build_transformer(model_path, shard, model_size="8B")
+    self.tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
     self.shard = shard
-    self.model = model
-    self.tokenizer = tokenizer

+ 53 - 125
exo/inference/tinygrad/models/llama.py

@@ -1,26 +1,22 @@
 from typing import Tuple, Union, Optional, Dict, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
-from exo.inference.shard import Shard
-
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
-  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[: (dim // 2)] / dim))
+  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
-  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
-
+  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim//2, 2)
 
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):
-  a, b = A[..., 0:1], A[..., 1:2]
-  ro = a * c - b * d
-  co = a * d + b * c
+  a,b = A[..., 0:1], A[..., 1:2]
+  ro = a*c - b*d
+  co = a*d + b*c
   return ro.cat(co, dim=-1)
 
-
-def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
+def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
   assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
@@ -30,19 +26,16 @@ def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor,
   xk_out = complex_mult(xk, c, d)
   return xq_out.flatten(3), xk_out.flatten(3)
 
-
-def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
+def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
   bs, seqlen, n_kv_heads, head_dim = x.shape
-  if n_rep == 1:
-    return x
+  if n_rep == 1: return x
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
 
-
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
-    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
+    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
     self.head_dim = dim // n_heads
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context
@@ -52,8 +45,14 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
-    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
+    if getenv("WQKV"):
+      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
+      xqkv = x @ self.wqkv.T
+      xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
+    else:
+      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
     xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
     xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
@@ -66,14 +65,14 @@ class Attention:
       self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
       if isinstance(x.device, tuple):
         # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-        self.cache_kv.shard_((x.device), axis=None).realize()
+        self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
 
     # update the cache
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
+    self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
-    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -81,39 +80,26 @@ class Attention:
     attn = attn.reshape(bsz, seqlen, -1)
     return self.wo(attn)
 
-
 class FeedForward:
-  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
+  def __init__(self, dim:int, hidden_dim:int, linear=nn.Linear):
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
-    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit
-
-  def __call__(self, x: Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu() * self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
+    self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
 
+  def __call__(self, x:Tensor) -> Tensor:
+    return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
 
 class TransformerBlock:
-  def __init__(
-    self,
-    dim: int,
-    hidden_dim: int,
-    n_heads: int,
-    n_kv_heads: int,
-    norm_eps: float,
-    max_context: int,
-    linear=nn.Linear,
-    feed_forward=FeedForward,
-  ):
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
+  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
-
 # standard openai sampling
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
@@ -121,8 +107,7 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
 
   # if temperature is very low just use argmax
-  if temp < 1e-6:
-    return logits.argmax()
+  if temp < 1e-6: return logits.argmax()
 
   # alpha sampling
   if af or ap:
@@ -136,16 +121,10 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   # softmax
   t = (logits / temp).softmax()
 
-  counter, counter2 = (
-    Tensor.arange(t.numel(), device=logits.device).contiguous(),
-    Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous(),
-  )
+  counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
   # top k
   if k:
-    output, output_indices = (
-      Tensor.zeros(k, device=logits.device).contiguous(),
-      Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous(),
-    )
+    output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
     for i in range(k):
       t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
       output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
@@ -170,84 +149,48 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 
   return output_token
 
+from exo.inference.shard import Shard
 
 class Transformer:
-  def __init__(
-    self,
-    dim: int,
-    hidden_dim: int,
-    n_heads: int,
-    n_layers: int,
-    norm_eps: float,
-    vocab_size,
-    shard: Shard,
-    linear=nn.Linear,
-    n_kv_heads=None,
-    rope_theta=10000,
-    max_context=1024,
-    jit=True,
-    feed_forward=FeedForward,
-  ):
-    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(shard.end_layer - shard.start_layer + 1)]
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, shard: Shard=None, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
+    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous()
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
 
-  def forward(
-    self,
-    h: Tensor,
-    start_pos: Union[Variable, int],
-    temperature: float,
-    top_k: int,
-    top_p: float,
-    alpha_f: float,
-    alpha_p: float,
-  ):
-    seqlen = h.shape[1]
-    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
+  def forward(self, x:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
+    seqlen = x.shape[1]
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
+    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos+1).realize() if seqlen > 1 else None
 
     if self.shard.is_first_layer():
-      h = self.tok_embeddings(h)
-    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos + 1).realize() if seqlen > 1 else None
+      h = self.tok_embeddings(x)
+    else:
+      h = x
 
-    for layer in self.layers:
+    for i in range(self.shard.start_layer, self.shard.end_layer + 1):
+      layer = self.layers[i]
       h = layer(h, start_pos, freqs_cis, mask)
 
     if self.shard.is_last_layer():
       logits = self.output(self.norm(h)).float()[:, -1, :]
       return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
     else:
-      return h.realize()
-
-  def __call__(
-    self,
-    tokens: Tensor,
-    start_pos: Variable,
-    temperature: float = 0.0,
-    top_k: int = 0,
-    top_p: float = 0.8,
-    alpha_f: float = 0.0,
-    alpha_p: float = 0.0,
-  ):
+      return h
+
+  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
-    # if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
-    #   return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
+    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
+      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
     return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
 
-  def reset(self):
-    for layer in self.layers:
-      if hasattr(layer.attention, "cache_kv"):
-        layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
-
-
 # *** helpers ***
 
-
-def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard):
+def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 
@@ -255,30 +198,16 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
     "model.embed_tokens.weight": "tok_embeddings.weight",
     **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
     **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.biases": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.scales": f"layers.{l}.attention.w{x}.scale" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
     **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.post_attention_layernorm.biases": f"layers.{l}.ffn_norm.bias" for l in range(len(model.layers))},
     **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.biases": f"layers.{l}.feed_forward.w{y}.bias" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.scales": f"layers.{l}.feed_forward.w{y}.scale" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
     "model.norm.weight": "norm.weight",
     "lm_head.weight": "output.weight",
-    "lm_head.biases": "output.bias",
-    "lm_head.scales": "output.scale",
   }
   sd = {}
   for k, v in weights.items():
-    if ".rotary_emb." in k:
-      continue
+    if ".rotary_emb." in k: continue
     v = v.to(Device.DEFAULT)
     if "model.layers" in k:
-      layer_num = int(k.split(".")[2])
-      if shard.start_layer <= layer_num <= shard.end_layer:
-        k = f"model.layers.{layer_num - shard.start_layer}." + ".".join(k.split(".")[3:])
-      else:
-        continue
-
       if "q_proj" in k:
         v = permute(v, n_heads)
       elif "k_proj" in k:
@@ -286,10 +215,9 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
     sd[keymap[k]] = v
   return sd
 
-
-def fix_bf16(weights: Dict[Any, Tensor]):
+def fix_bf16(weights:Dict[Any, Tensor]):
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
+    return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
   # TODO: check if device supports bf16
-  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
+  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}

+ 38 - 0
exo/inference/tinygrad/tinygrad_helpers.py

@@ -0,0 +1,38 @@
+from tinygrad.nn.state import safe_load, torch_load
+from tinygrad import Tensor
+from pathlib import Path
+import json
+from typing import List
+from exo.inference.shard import Shard
+from exo.helpers import DEBUG
+
+# **** helper functions ****
+def concat_weights(models, device=None):
+  def convert(name) -> Tensor:
+    disk_tensors: List[Tensor] = [model[name] for model in models]
+    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
+      return disk_tensors[0].to(device=device)
+    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
+    lazy_tensors = [data.to(device=device) for data in disk_tensors]
+    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
+  return {name: convert(name) for name in {name: None for model in models for name in model}}
+
+def load(fn:str, shard: Shard):
+  if fn.endswith('.index.json'):
+    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
+    parts = {}
+    filtered_weight_map = {}
+    for k, n in weight_map.items():
+      if k.startswith("model.layers."):
+        layer_num = int(k.split('.')[2])
+        if layer_num < shard.start_layer or layer_num > shard.end_layer:
+          continue
+
+      parts[n] = load(str(Path(fn).parent / Path(n).name), shard)
+      filtered_weight_map[k] = n
+    if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {set(weight_map.keys()) - set(filtered_weight_map.keys())}")
+    return {k: parts[n][k] for k, n in filtered_weight_map.items()}
+  elif fn.endswith(".safetensors"):
+    return safe_load(fn)
+  else:
+    return torch_load(fn)