
Implemented per-request caching in tinygrad

Nel Nibcord 8 months ago
parent
commit 8205a5aebc
2 changed files with 49 additions and 30 deletions
  1. exo/inference/tinygrad/inference.py  +1 -2
  2. exo/inference/tinygrad/models/llama.py  +48 -28

+ 1 - 2
exo/inference/tinygrad/inference.py

@@ -58,7 +58,6 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
     load_state_dict(model, weights, strict=False, consume=False)  # consume=True
   return model

-
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
@@ -85,7 +84,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos).realize())
+    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, request_id).realize())
     return output_data.numpy()

   async def ensure_shard(self, shard: Shard):
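The functional change on the engine side is simply forwarding the request_id it already receives into the model call; start_pos still travels in the JSON-encoded inference_state. A hypothetical caller sketch (decode_step and its arguments are illustrative, not part of this commit):

import json
import numpy as np

async def decode_step(engine, shard, request_id: str, tokens: np.ndarray, start_pos: int) -> np.ndarray:
  # request_id now also selects which KV cache the model uses
  state = json.dumps({"start_pos": start_pos})
  return await engine.infer_tensor(request_id, shard, tokens, inference_state=state)

# interleaved requests keep independent caches, e.g.:
#   out_a = await decode_step(engine, shard, "req-a", tokens_a, start_pos_a)
#   out_b = await decode_step(engine, shard, "req-b", tokens_b, start_pos_b)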

+ 48 - 28
exo/inference/tinygrad/models/llama.py

@@ -1,6 +1,7 @@
 from typing import Tuple, Union, Optional, Dict, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
+from collections import OrderedDict


 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
@@ -47,6 +48,12 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)

+def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
+  cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
+  if isinstance(x.device, tuple):
+    # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
+    cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
+  return cache_kv.realize()

 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
@@ -61,7 +68,7 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
     self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)

-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None) -> Tensor:
     if getenv("WQKV"):
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
       xqkv = x @ self.wqkv.T
@@ -76,19 +83,16 @@ class Attention:
     xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
     bsz, seqlen, _, _ = xq.shape

-    # create kv cache
-    if not hasattr(self, "cache_kv"):
-      self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
-      if isinstance(x.device, tuple):
-        # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-        self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
+    if cache is not None:
+      # update the cache
+      assert xk.dtype == xv.dtype == cache.dtype, f"{xk.dtype=}, {xv.dtype=}, {cache.dtype=}"
+      cache.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()

-    # update the cache
-    assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
-
-    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+      keys = cache[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
+      values = cache[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+    else:
+      keys = xk
+      values = xv

     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -114,8 +118,8 @@ class TransformerBlock:
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)

-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
-    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None):
+    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask, cache=cache)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()


@@ -189,7 +193,8 @@ class Transformer:
     jit=True,
     feed_forward=FeedForward,
     rope_scaling: Optional[Dict[str, float]] = None,
-    tie_word_embeddings=False
+    tie_word_embeddings=False,
+    max_caches=2,
   ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
@@ -201,19 +206,19 @@ class Transformer:
     self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
+    self.caches = OrderedDict()
+    self.max_caches = max_caches

-  def forward(self, x: Tensor, start_pos: Union[Variable, int]):
-    if self.shard.is_first_layer():
-      h = self.tok_embeddings(x)
-    else:
-      h = x
-    seqlen = h.shape[1]
+  def forward(self, x: Tensor, start_pos: Union[Variable, int], request_id: str):
+    seqlen = x.shape[1]
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
-    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=h.dtype, device=h.device).triu(start_pos + 1).realize() if seqlen > 1 else None
+    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
+
+    h = x

-    for i in range(self.shard.start_layer, self.shard.end_layer + 1):
+    for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), self.caches[request_id]):
       layer = self.layers[i]
-      h = layer(h, start_pos, freqs_cis, mask)
+      h = layer(h, start_pos, freqs_cis, mask, cache=c)

     if self.shard.is_last_layer():
       logits = self.output(self.norm(h)).float().realize()
@@ -221,11 +226,26 @@ class Transformer:
     else:
       return h

-  def __call__(self, tokens: Tensor, start_pos: Variable):
+  def init_cache(self, x: Tensor, request_id: str):
+    cache = [create_kv_cache(x, self.layers[i].attention.max_context, self.layers[i].attention.n_kv_heads, self.layers[i].attention.head_dim) for i in range(self.shard.start_layer, self.shard.end_layer + 1)]
+    if len(self.caches) >= self.max_caches:
+      self.caches.popitem(last=False)
+
+    self.caches[request_id] = cache
+
+  def __call__(self, tokens: Tensor, start_pos: Variable, request_id: str):
+    if self.shard.is_first_layer():
+      h = self.tok_embeddings(tokens)
+    else:
+      h = tokens
+    if request_id not in self.caches:
+      self.init_cache(h, request_id)
+    else:
+      self.caches.move_to_end(request_id)
     # TODO: better way to handle the first call v.s. the rest?
     if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
-      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos))
-    return self.forward(tokens, start_pos)
+      return self.forward_jit(h, Variable("start_pos", 0, self.max_context).bind(start_pos), request_id)
+    return self.forward(h, start_pos, request_id)


 # *** helpers ***
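For reference, each per-layer cache built by create_kv_cache has shape (2, bsz, max_context, n_kv_heads, head_dim), with index 0 holding keys and index 1 holding values; Attention.__call__ writes the new keys/values into positions [start_pos, start_pos + seqlen) and then reads back the whole prefix. A NumPy analogue of that update, for illustration only (the real code operates on realized tinygrad Tensors via shrink/assign):

import numpy as np

def update_kv_cache(cache: np.ndarray, xk: np.ndarray, xv: np.ndarray, start_pos: int):
  # cache: (2, bsz, max_context, n_kv_heads, head_dim); cache[0] holds keys, cache[1] holds values
  seqlen = xk.shape[1]
  cache[:, :, start_pos:start_pos + seqlen] = np.stack([xk, xv])
  keys = cache[0, :, :start_pos + seqlen] if start_pos > 0 else xk
  values = cache[1, :, :start_pos + seqlen] if start_pos > 0 else xv
  return keys, values

bsz, max_context, n_kv_heads, head_dim = 1, 8, 2, 4
cache = np.zeros((2, bsz, max_context, n_kv_heads, head_dim), dtype=np.float32)
xk = np.ones((bsz, 3, n_kv_heads, head_dim), dtype=np.float32)  # prefill of 3 positions
xv = np.ones((bsz, 3, n_kv_heads, head_dim), dtype=np.float32)
keys, values = update_kv_cache(cache, xk, xv, start_pos=0)                 # prefill: returns xk, xv directly
keys, values = update_kv_cache(cache, xk[:, :1], xv[:, :1], start_pos=3)   # decode step: keys now covers 4 positions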