
Proper sharding in tinygrad

Now tinygrad will build a model from just the shard, just like mlx does
Nel Nibcord 4 months ago
Parent
Commit
b1397b49be
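
A rough sketch of the intended usage, assuming Shard is the dataclass from exo.inference.shard with model_id, start_layer, end_layer and n_layers fields (field names taken from how the diff below uses them; the concrete values here are made up):

  from pathlib import Path
  from exo.inference.shard import Shard
  from exo.inference.tinygrad.inference import build_transformer

  # A shard covering the first half of a hypothetical 32-layer model.
  shard = Shard(model_id="llama-3-8b", start_layer=0, end_layer=15, n_layers=32)

  # build_transformer now wraps the loaded Transformer in a TransformerShard, so the
  # returned model holds only layers 0..15, applies tok_embeddings because this shard
  # owns the first layer, and skips the output head because it does not own the last one.
  model = build_transformer(Path("/path/to/model"), shard, model_size="8B")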

+ 4 - 2
exo/inference/tinygrad/inference.py

@@ -1,7 +1,7 @@
 from pathlib import Path
 import json
 import os
-from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16, sample_logits
+from exo.inference.tinygrad.models.llama import Transformer, TransformerShard, convert_from_huggingface, fix_bf16, sample_logits
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
@@ -57,6 +57,8 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
   with Context(BEAM=0):
     # replace weights in model
     load_state_dict(model, weights, strict=False, consume=False)  # consume=True
+    model = TransformerShard(shard, model)
+
   return model
 
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
@@ -70,7 +72,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     if request_id not in self.states:
       if len(self.states) >= max_states:
         self.states.popitem(last=False)
-      self.states[request_id] = make_prompt_state(x, self.model, self.shard)
+      self.states[request_id] = make_prompt_state(x, self.model)
     else:
       self.states.move_to_end(request_id)
     state = self.states[request_id]
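
For context, the surrounding state handling is a plain OrderedDict used as an LRU over request ids; a standalone illustration of that pattern with toy values (not the real ModelState objects):

  from collections import OrderedDict

  max_states = 2
  states = OrderedDict()
  for request_id in ["a", "b", "c"]:
    if request_id not in states:
      if len(states) >= max_states:
        states.popitem(last=False)   # evict the least recently used request
      states[request_id] = f"state for {request_id}"
    else:
      states.move_to_end(request_id)  # mark as most recently used
  assert list(states) == ["b", "c"]   # "a" was evicted when "c" arrived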

+ 39 - 0
exo/inference/tinygrad/models/llama.py

@@ -235,7 +235,46 @@ class Transformer:
     h = self.embed(x)
     return self.forward(h, start_pos, cache=cache)
 
+class TransformerShard:
+  def __init__(
+    self,
+    shard: Shard,
+    base,
+    jit: bool = True,
+  ):
+    shardrange = range(shard.start_layer, shard.end_layer + 1)
+    self.layers = [layer for layer, n in zip(base.layers, range(shard.n_layers)) if n in shardrange]
+    self.norm = base.norm 
+    self.tok_embeddings = base.tok_embeddings
+    self.embed = (lambda x: self.tok_embeddings(x)) if shard.is_first_layer() else (lambda x: x)
+    self.output = base.output
+    self.post = (lambda x: self.output(x)) if shard.is_last_layer() else (lambda x: x)
+    self.max_context = base.max_context
+    self.null_cache = [None for _ in shardrange] 
+    self.freqs_cis = base.freqs_cis
+    self.forward_jit = TinyJit(self.forward_base) if jit else None
+
+  def forward_base(self, x: Tensor, start_pos: Union[Variable, int], cache):
+    seqlen = x.shape[1]
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
+    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
+
+    for layer, c in zip(self.layers, cache):
+      x = layer(x, start_pos, freqs_cis, mask, cache=c)
 
+    out = self.post(x)
+    return out
+
+  def forward(self, x: Tensor, start_pos: int, cache: Optional[List[Tensor]] = None):
+    if x.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
+      return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
+    return self.forward_base(x, start_pos, cache=cache)
+
+  def __call__(self, x: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
+    # TODO: better way to handle the first call v.s. the rest?
+    h = self.embed(x)
+    return self.forward(h, start_pos, cache=self.null_cache if cache is None else cache)
+      
 # *** helpers ***
 
 

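The layer selection in TransformerShard.__init__ keeps end_layer inclusive; a minimal standalone illustration of that comprehension with toy objects (not the real Transformer or Shard):

  from dataclasses import dataclass

  @dataclass
  class ToyShard:  # stand-in for exo.inference.shard.Shard
    start_layer: int
    end_layer: int
    n_layers: int

  shard = ToyShard(start_layer=4, end_layer=7, n_layers=12)
  base_layers = [f"layer{i}" for i in range(shard.n_layers)]

  shardrange = range(shard.start_layer, shard.end_layer + 1)
  picked = [layer for layer, n in zip(base_layers, range(shard.n_layers)) if n in shardrange]
  assert picked == ["layer4", "layer5", "layer6", "layer7"]  # end_layer is inclusive
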
+ 4 - 4
exo/inference/tinygrad/stateful_model.py

@@ -2,8 +2,8 @@ from tinygrad import Tensor, Variable
 from collections import OrderedDict
 from typing import List, Optional
 
-def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
-  cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
+def create_kv_cache(x: Tensor, layer):
+  cache_kv = Tensor.zeros(2, x.shape[0], layer.max_context, layer.n_kv_heads, layer.head_dim, dtype=x.dtype).contiguous().realize()
   if isinstance(x.device, tuple):
     # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
     cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
@@ -16,7 +16,7 @@ class ModelState:
     self.cache = cache
     self.start = start
 
-def make_prompt_state(x, model, shard):
-  cache = [create_kv_cache(x, model.layers[i].attention.max_context, model.layers[i].attention.n_kv_heads, model.layers[i].attention.head_dim) for i in range(shard.start_layer, shard.end_layer + 1)]
+def make_prompt_state(x: Tensor, model):
+  cache = [create_kv_cache(x, l.attention) for l in model.layers]
 
   return ModelState(cache)
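
Since the model now carries only its own layers, the cache setup no longer needs the shard. A hedged usage sketch (ModelState defaulting start to 0 and the Attention attributes read by create_kv_cache are assumptions based on the code above):

  # x is the input token Tensor for a new request; model is the TransformerShard
  # returned by build_transformer.
  state = make_prompt_state(x, model)   # one KV-cache entry per layer this shard owns
  out = model(x, state.start, cache=state.cache)
  # out is logits if this shard owns the last layer, otherwise the hidden states
  # to hand to the next shard.
  state.start += x.shape[1]             # advance past the tokens just processed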