
refactor tinygrad: only load the necessary layers for each shard (fixes #128), enable the JIT (much faster), prefill across all layers rather than just the first shard (fixes #12), and use the new ShardDownloader for more robust, parallel downloads

Alex Cheema · 1 year ago
commit 2be446546f

+ 4 - 3
exo/inference/test_inference_engine.py

@@ -11,6 +11,7 @@ import numpy as np
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+
   next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
@@ -52,7 +53,7 @@ asyncio.run(
 
 # TODO: Need more memory or a smaller model
 # asyncio.run(test_inference_engine(
-#     TinygradDynamicShardInferenceEngine(),
-#     TinygradDynamicShardInferenceEngine(),
-#     "mlx-community/Meta-Llama-3-8B-Instruct",
+#     TinygradDynamicShardInferenceEngine(HFShardDownloader()),
+#     TinygradDynamicShardInferenceEngine(HFShardDownloader()),
+#     "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
 # ))
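
For context, a minimal usage sketch of the refactored engine, based on the updated test above. The module paths for Shard and TinygradDynamicShardInferenceEngine appear in this commit; the import path for HFShardDownloader is an assumption and may differ in the repository.

# Hedged sketch (not part of the commit): how the new constructor and
# infer_prompt signature fit together. The HFShardDownloader import path is assumed.
import asyncio
from exo.inference.shard import Shard
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
from exo.download.hf.hf_shard_download import HFShardDownloader  # path assumed

async def main():
  engine = TinygradDynamicShardInferenceEngine(HFShardDownloader())
  # A shard covering layers 0-31 of a 32-layer model, i.e. the whole model on one node.
  shard = Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=31, n_layers=32)
  # infer_prompt prefills the full prompt in a single model call and returns
  # (output, serialized inference_state, is_finished).
  output, state, is_finished = await engine.infer_prompt("A", shard=shard, prompt="Hello")
  print(output, state, is_finished)

asyncio.run(main())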

+ 48 - 150
exo/inference/tinygrad/inference.py

@@ -1,199 +1,97 @@
 from pathlib import Path
-from typing import List, Optional
+from typing import List
 import json
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
-from tinygrad.nn.state import safe_load, torch_load, load_state_dict
-from tinygrad import Tensor, nn, Context, GlobalCounters
-from tinygrad.helpers import tqdm
 from exo.inference.shard import Shard
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict
+from tinygrad import Tensor, dtypes, nn, Context
+from transformers import AutoTokenizer
 from exo.inference.inference_engine import InferenceEngine
+from typing import Optional, Tuple
 import numpy as np
+from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
 
+Tensor.no_grad = True
+# default settings
+TEMPERATURE = 0.85
+TOP_K = 25
+TOP_P = 0.9
+ALPHA_F = 0.1
+ALPHA_P = 0.0
 MODEL_PARAMS = {
   "8B": {
-    "args": {
-      "dim": 4096,
-      "n_heads": 32,
-      "n_kv_heads": 8,
-      "n_layers": 32,
-      "norm_eps": 1e-5,
-      "rope_theta": 500000,
-      "vocab_size": 128256,
-      "hidden_dim": 14336,
-    },
-    "files": 1,
+    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
+    "files": 1
   },
   "70B": {
-    "args": {
-      "dim": 8192,
-      "n_heads": 64,
-      "n_kv_heads": 8,
-      "n_layers": 80,
-      "norm_eps": 1e-5,
-      "rope_theta": 500000,
-      "vocab_size": 128256,
-      "hidden_dim": 28672,
-    },
-    "files": 8,
-  },
+    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 28672},
+    "files": 8
+  }
 }
 
-
-
-# **** helper functions ****
-
-def concat_weights(models, device=None):
-  def convert(name) -> Tensor:
-    disk_tensors: List[Tensor] = [model[name] for model in models]
-    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
-      return disk_tensors[0].to(device=device)
-    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
-    lazy_tensors = [data.to(device=device) for data in disk_tensors]
-    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
-
-  return {name: convert(name) for name in {name: None for model in models for name in model}}
-
-
-def load(fn: str):
-  if fn.endswith(".index.json"):
-    with open(fn) as fp:
-      weight_map = json.load(fp)["weight_map"]
-    parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
-    return {k: parts[n][k] for k, n in weight_map.items()}
-  elif fn.endswith(".safetensors"):
-    return safe_load(fn)
-  else:
-    return torch_load(fn)
-
-
-def build_transformer(model_path: Path, shard: Shard, model_size="8B", quantize=None, device=None):
+def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   # build model
   linear = nn.Linear
   with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], shard=shard, linear=linear, max_context=8192, jit=False)
+    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
 
   # load weights
   if model_path.is_dir():
-    if (model_path / "model.safetensors.index.json").exists():
-      weights = load(str(model_path / "model.safetensors.index.json"))
-    elif (model_path / "model.safetensors").exists():
-      weights = load(str(model_path / "model.safetensors"))
-    else:
-      weights = concat_weights(
-        [load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])],
-        device[0] if isinstance(device, tuple) else device,
-      )
+    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"), shard)
+    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"), shard)
+    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
   else:
-    weights = load(str(model_path))
-
+    weights = load(str(model_path), shard)
   if "model.embed_tokens.weight" in weights:
-    weights = convert_from_huggingface(
-      weights,
-      model,
-      MODEL_PARAMS[model_size]["args"]["n_heads"],
-      MODEL_PARAMS[model_size]["args"]["n_kv_heads"],
-      shard=shard,
-    )
+    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
   weights = fix_bf16(weights)
 
   with Context(BEAM=0):
-    # quantize
-    if quantize is not None:
-      weights = linear.quantize(weights, device)
-      for _, v in weights.items():
-        v.realize()
-
-    # shard
-    if isinstance(device, tuple):
-      for k, v in nn.state.get_state_dict(model).items():
-        if "scale" in k:
-          v.shard_(device, axis=None)  # from quantized
-        elif ".attention." in k:
-          v.shard_(device, axis=-1)
-        elif ".feed_forward.w1." in k:
-          v.shard_(device, axis=0)
-        elif ".feed_forward.w3." in k:
-          v.shard_(device, axis=0)
-        elif ".feed_forward." in k:
-          v.shard_(device, axis=-1)
-        elif "tok_embeddings.weight" in k:
-          v.shard_(device, axis=0)
-        elif "output.weight" in k:
-          v.shard_(device, axis=0)
-        else:
-          v.shard_(device, axis=None)
-
     # replace weights in model
-    load_state_dict(model, weights, strict=False, consume=True)
+    load_state_dict(model, weights, strict=False, consume=False) # consume=True
   return model
 
-
-# default settings
-TEMPERATURE = 0  # 0.85
-TOP_K = 25
-TOP_P = 0.9
-ALPHA_F = 0.1
-ALPHA_P = 0.0
-
-
-def prefill(model, toks, start_pos=0):
-  # prefill the model
-  for tok in tqdm(toks):
-    GlobalCounters.reset()
-    model(Tensor([[tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).realize()
-    start_pos += 1
-  return start_pos
-
-
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-    # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
     await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
+    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
+    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
     toks = self.tokenizer.encode(prompt)
-    start_pos = prefill(self.model, toks[:-1], start_pos=start_pos)
-    last_tok = toks[-1]
+    h = self.model(Tensor([toks]), start_pos, TEMPERATURE)
 
-    output_data = np.array([self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
-    if output_data.size == 1:
-      start_pos += 1
-
-    return (
-      output_data,
-      json.dumps({"start_pos": start_pos}),
-      output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id],
-    )
+    if h.shape == (1,):
+      start_pos += len(toks)
+      n_captured_toks = 1
+      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
+    else:
+      n_captured_toks += len(toks)
+      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks + 1}), False
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
+    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
+    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
-    output_data: np.ndarray = np.array([self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
-    if output_data.size == 1:
-      start_pos += 1
+    h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
 
-    return (
-      output_data,
-      json.dumps({"start_pos": start_pos}),
-      output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id],
-    )
+    if h.shape == (1,):
+      start_pos += n_captured_toks
+      n_captured_toks = 1
+      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
+    else:
+      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return
 
     model_path = await self.shard_downloader.ensure_shard(shard)
-    print(f"{model_path=}")
-    model = build_transformer(model_path, shard=shard, model_size="8B" if "8b" in shard.model_id else "70B" if "70b" in shard.model_id else "8B")
-    from transformers import AutoTokenizer
-    tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
-
+    self.model = build_transformer(model_path, shard, model_size="8B")
+    self.tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
     self.shard = shard
-    self.model = model
-    self.tokenizer = tokenizer
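
The inference_state JSON introduced above carries start_pos and n_captured_toks between calls, and an output of shape (1,) signals a sampled token from the shard that holds the last layer; anything else is an intermediate activation to hand to the next shard. Below is a hedged sketch of chaining two shards through that protocol, mirroring the test file; the engines are constructed as in the earlier sketch and error handling is omitted.

# Hedged sketch: chaining two shards of a 32-layer model through the
# (output, inference_state, is_finished) tuple returned by the engine.
import json
from exo.inference.shard import Shard

async def generate_one_token(engine_a, engine_b, model_id: str, prompt: str):
  shard_a = Shard(model_id=model_id, start_layer=0, end_layer=15, n_layers=32)
  shard_b = Shard(model_id=model_id, start_layer=16, end_layer=31, n_layers=32)

  # The first shard prefills the whole prompt in one call (no per-token prefill
  # loop anymore) and, since it does not hold the last layer, returns hidden states.
  hidden, state, _ = await engine_a.infer_prompt("A", shard=shard_a, prompt=prompt)

  # The second shard holds the last layer, so it samples a token (shape (1, 1))
  # and its returned state advances start_pos by the number of captured tokens.
  token, state, is_finished = await engine_b.infer_tensor("A", shard=shard_b, input_data=hidden, inference_state=state)
  return token, json.loads(state)["start_pos"], is_finished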

+ 53 - 125
exo/inference/tinygrad/models/llama.py

@@ -1,26 +1,22 @@
 from typing import Tuple, Union, Optional, Dict, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
-from exo.inference.shard import Shard
-
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
-  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[: (dim // 2)] / dim))
+  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
-  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
-
+  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim//2, 2)
 
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):
-  a, b = A[..., 0:1], A[..., 1:2]
-  ro = a * c - b * d
-  co = a * d + b * c
+  a,b = A[..., 0:1], A[..., 1:2]
+  ro = a*c - b*d
+  co = a*d + b*c
   return ro.cat(co, dim=-1)
 
-
-def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
+def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
   assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
@@ -30,19 +26,16 @@ def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor,
   xk_out = complex_mult(xk, c, d)
   return xq_out.flatten(3), xk_out.flatten(3)
 
-
-def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
+def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
   bs, seqlen, n_kv_heads, head_dim = x.shape
-  if n_rep == 1:
-    return x
+  if n_rep == 1: return x
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
 
-
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
-    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
+    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
     self.head_dim = dim // n_heads
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context
@@ -52,8 +45,14 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
-    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
+    if getenv("WQKV"):
+      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
+      xqkv = x @ self.wqkv.T
+      xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
+    else:
+      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
     xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
     xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
@@ -66,14 +65,14 @@ class Attention:
       self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
       if isinstance(x.device, tuple):
         # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-        self.cache_kv.shard_((x.device), axis=None).realize()
+        self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
 
     # update the cache
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
+    self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
-    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -81,39 +80,26 @@ class Attention:
     attn = attn.reshape(bsz, seqlen, -1)
     return self.wo(attn)
 
-
 class FeedForward:
-  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
+  def __init__(self, dim:int, hidden_dim:int, linear=nn.Linear):
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
-    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit
-
-  def __call__(self, x: Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu() * self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
+    self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
 
+  def __call__(self, x:Tensor) -> Tensor:
+    return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
 
 class TransformerBlock:
-  def __init__(
-    self,
-    dim: int,
-    hidden_dim: int,
-    n_heads: int,
-    n_kv_heads: int,
-    norm_eps: float,
-    max_context: int,
-    linear=nn.Linear,
-    feed_forward=FeedForward,
-  ):
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
+  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
-
 # standard openai sampling
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
@@ -121,8 +107,7 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
 
   # if temperature is very low just use argmax
-  if temp < 1e-6:
-    return logits.argmax()
+  if temp < 1e-6: return logits.argmax()
 
   # alpha sampling
   if af or ap:
@@ -136,16 +121,10 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   # softmax
   t = (logits / temp).softmax()
 
-  counter, counter2 = (
-    Tensor.arange(t.numel(), device=logits.device).contiguous(),
-    Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous(),
-  )
+  counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
   # top k
   if k:
-    output, output_indices = (
-      Tensor.zeros(k, device=logits.device).contiguous(),
-      Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous(),
-    )
+    output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
     for i in range(k):
       t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
       output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
@@ -170,84 +149,48 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 
   return output_token
 
+from exo.inference.shard import Shard
 
 class Transformer:
-  def __init__(
-    self,
-    dim: int,
-    hidden_dim: int,
-    n_heads: int,
-    n_layers: int,
-    norm_eps: float,
-    vocab_size,
-    shard: Shard,
-    linear=nn.Linear,
-    n_kv_heads=None,
-    rope_theta=10000,
-    max_context=1024,
-    jit=True,
-    feed_forward=FeedForward,
-  ):
-    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(shard.end_layer - shard.start_layer + 1)]
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, shard: Shard=None, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
+    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous()
    self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
 
-  def forward(
-    self,
-    h: Tensor,
-    start_pos: Union[Variable, int],
-    temperature: float,
-    top_k: int,
-    top_p: float,
-    alpha_f: float,
-    alpha_p: float,
-  ):
-    seqlen = h.shape[1]
-    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
+  def forward(self, x:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
+    seqlen = x.shape[1]
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
+    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos+1).realize() if seqlen > 1 else None
 
     if self.shard.is_first_layer():
-      h = self.tok_embeddings(h)
-    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos + 1).realize() if seqlen > 1 else None
+      h = self.tok_embeddings(x)
+    else:
+      h = x
 
-    for layer in self.layers:
+    for i in range(self.shard.start_layer, self.shard.end_layer + 1):
+      layer = self.layers[i]
       h = layer(h, start_pos, freqs_cis, mask)
 
     if self.shard.is_last_layer():
       logits = self.output(self.norm(h)).float()[:, -1, :]
       return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
     else:
-      return h.realize()
-
-  def __call__(
-    self,
-    tokens: Tensor,
-    start_pos: Variable,
-    temperature: float = 0.0,
-    top_k: int = 0,
-    top_p: float = 0.8,
-    alpha_f: float = 0.0,
-    alpha_p: float = 0.0,
-  ):
+      return h
+
+  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
-    # if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
-    #   return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
+    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
+      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
     return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
 
-  def reset(self):
-    for layer in self.layers:
-      if hasattr(layer.attention, "cache_kv"):
-        layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
-
-
 # *** helpers ***
 
-
-def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard):
+def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 
@@ -255,30 +198,16 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
     "model.embed_tokens.weight": "tok_embeddings.weight",
     **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
     **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.biases": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.scales": f"layers.{l}.attention.w{x}.scale" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
     **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.post_attention_layernorm.biases": f"layers.{l}.ffn_norm.bias" for l in range(len(model.layers))},
     **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.biases": f"layers.{l}.feed_forward.w{y}.bias" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.scales": f"layers.{l}.feed_forward.w{y}.scale" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
     "model.norm.weight": "norm.weight",
     "lm_head.weight": "output.weight",
-    "lm_head.biases": "output.bias",
-    "lm_head.scales": "output.scale",
   }
   sd = {}
   for k, v in weights.items():
-    if ".rotary_emb." in k:
-      continue
+    if ".rotary_emb." in k: continue
     v = v.to(Device.DEFAULT)
     if "model.layers" in k:
-      layer_num = int(k.split(".")[2])
-      if shard.start_layer <= layer_num <= shard.end_layer:
-        k = f"model.layers.{layer_num - shard.start_layer}." + ".".join(k.split(".")[3:])
-      else:
-        continue
-
       if "q_proj" in k:
         v = permute(v, n_heads)
       elif "k_proj" in k:
@@ -286,10 +215,9 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
     sd[keymap[k]] = v
   return sd
 
-
-def fix_bf16(weights: Dict[Any, Tensor]):
+def fix_bf16(weights:Dict[Any, Tensor]):
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
+    return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
   # TODO: check if device supports bf16
-  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
+  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
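
The forward pass above relies on three things from the Shard object: start_layer/end_layer to select which TransformerBlocks actually run, is_first_layer() to decide whether to embed tokens, and is_last_layer() to decide whether to apply the output head and sample. A minimal sketch of that contract, assuming exo.inference.shard.Shard behaves roughly like this (the real class may carry more fields or methods):

# Hedged sketch of the Shard contract assumed by Transformer.forward above.
from dataclasses import dataclass

@dataclass(frozen=True)
class Shard:
  model_id: str
  start_layer: int  # first layer this node runs (inclusive)
  end_layer: int    # last layer this node runs (inclusive)
  n_layers: int     # total layers in the full model

  def is_first_layer(self) -> bool:
    # Only the shard holding layer 0 embeds tokens; later shards receive hidden states.
    return self.start_layer == 0

  def is_last_layer(self) -> bool:
    # Only the shard holding the final layer applies norm/output and samples a token.
    return self.end_layer == self.n_layers - 1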

+ 38 - 0
exo/inference/tinygrad/tinygrad_helpers.py

@@ -0,0 +1,38 @@
+from tinygrad.nn.state import safe_load, torch_load
+from tinygrad import Tensor
+from pathlib import Path
+import json
+from typing import List
+from exo.inference.shard import Shard
+from exo.helpers import DEBUG
+
+# **** helper functions ****
+def concat_weights(models, device=None):
+  def convert(name) -> Tensor:
+    disk_tensors: List[Tensor] = [model[name] for model in models]
+    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
+      return disk_tensors[0].to(device=device)
+    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
+    lazy_tensors = [data.to(device=device) for data in disk_tensors]
+    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
+  return {name: convert(name) for name in {name: None for model in models for name in model}}
+
+def load(fn:str, shard: Shard):
+  if fn.endswith('.index.json'):
+    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
+    parts = {}
+    filtered_weight_map = {}
+    for k, n in weight_map.items():
+      if k.startswith("model.layers."):
+        layer_num = int(k.split('.')[2])
+        if layer_num < shard.start_layer or layer_num > shard.end_layer:
+          continue
+
+      parts[n] = load(str(Path(fn).parent / Path(n).name), shard)
+      filtered_weight_map[k] = n
+    if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {set(weight_map.keys()) - set(filtered_weight_map.keys())}")
+    return {k: parts[n][k] for k, n in filtered_weight_map.items()}
+  elif fn.endswith(".safetensors"):
+    return safe_load(fn)
+  else:
+    return torch_load(fn)
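
To illustrate the shard-aware load above: for a safetensors index, only keys whose layer number falls inside [start_layer, end_layer] (plus non-layer keys such as the embeddings and output head) are kept, so weight files that hold only excluded layers are never opened. A hedged usage sketch; the path and model_id below are illustrative, not taken from this commit.

# Hedged sketch: loading only the tensors a shard needs from a sharded checkpoint.
from exo.inference.shard import Shard
from exo.inference.tinygrad.tinygrad_helpers import load

# Layers 0-15 of a 32-layer model; keys like "model.layers.31.*" are filtered out,
# so the safetensors files that only contain those layers are never read.
shard = Shard(model_id="llama-3-8b", start_layer=0, end_layer=15, n_layers=32)
weights = load("/path/to/model/model.safetensors.index.json", shard)  # path illustrative
print(f"{len(weights)} tensors loaded for {shard}")

Running with DEBUG >= 2 prints the parameter keys that were excluded for the shard, which is useful for verifying the filtering.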