|
@@ -1,6 +1,7 @@
|
|
|
from typing import Tuple, Union, Optional, Dict, Any
|
|
|
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
|
|
|
from tinygrad.helpers import getenv
|
|
|
+from collections import OrderedDict
|
|
|
|
|
|
|
|
|
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
|
|
@@ -47,6 +48,12 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
|
|
# NOTE: this is different from x.repeat((1, 1, n_rep, 1))
|
|
|
return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
|
|
|
|
|
|
+def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
|
|
|
+ cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
|
|
|
+ if isinstance(x.device, tuple):
|
|
|
+ # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
|
|
|
+ cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
|
|
|
+ return cache_kv.realize()
|
|
|
|
|
|
class Attention:
|
|
|
def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
|
|
@@ -61,7 +68,7 @@ class Attention:
|
|
|
self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
|
|
|
self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
|
|
|
|
|
|
- def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
|
|
|
+ def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None) -> Tensor:
|
|
|
if getenv("WQKV"):
|
|
|
if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
|
|
|
xqkv = x @ self.wqkv.T
|
|
@@ -76,19 +83,16 @@ class Attention:
|
|
|
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
|
|
|
bsz, seqlen, _, _ = xq.shape
|
|
|
|
|
|
- # create kv cache
|
|
|
- if not hasattr(self, "cache_kv"):
|
|
|
- self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
|
|
|
- if isinstance(x.device, tuple):
|
|
|
- # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
|
|
|
- self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
|
|
|
+ if cache is not None:
|
|
|
+ # update the cache
|
|
|
+ assert xk.dtype == xv.dtype == cache.dtype, f"{xk.dtype=}, {xv.dtype=}, {cache.dtype=}"
|
|
|
+ cache.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
|
|
|
|
|
|
- # update the cache
|
|
|
- assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
|
|
|
- self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
|
|
|
-
|
|
|
- keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
|
|
|
- values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
|
|
|
+ keys = cache[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
|
|
|
+ values = cache[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
|
|
|
+ else:
|
|
|
+ keys = xk
|
|
|
+ values = xv
|
|
|
|
|
|
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
|
|
|
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
|
|
@@ -114,8 +118,8 @@ class TransformerBlock:
|
|
|
self.attention_norm = nn.RMSNorm(dim, norm_eps)
|
|
|
self.ffn_norm = nn.RMSNorm(dim, norm_eps)
|
|
|
|
|
|
- def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
|
|
|
- h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
|
|
+ def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None):
|
|
|
+ h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask, cache=cache)
|
|
|
return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
|
|
|
|
|
|
|
|
@@ -189,7 +193,8 @@ class Transformer:
|
|
|
jit=True,
|
|
|
feed_forward=FeedForward,
|
|
|
rope_scaling: Optional[Dict[str, float]] = None,
|
|
|
- tie_word_embeddings=False
|
|
|
+ tie_word_embeddings=False,
|
|
|
+ max_caches=2,
|
|
|
):
|
|
|
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
|
|
|
self.norm = nn.RMSNorm(dim, norm_eps)
|
|
@@ -201,19 +206,19 @@ class Transformer:
|
|
|
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
|
|
|
self.forward_jit = TinyJit(self.forward) if jit else None
|
|
|
self.shard = shard
|
|
|
+ self.caches = OrderedDict()
|
|
|
+ self.max_caches = max_caches
|
|
|
|
|
|
- def forward(self, x: Tensor, start_pos: Union[Variable, int]):
|
|
|
- if self.shard.is_first_layer():
|
|
|
- h = self.tok_embeddings(x)
|
|
|
- else:
|
|
|
- h = x
|
|
|
- seqlen = h.shape[1]
|
|
|
+ def forward(self, x: Tensor, start_pos: Union[Variable, int], request_id: str):
|
|
|
+ seqlen = x.shape[1]
|
|
|
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
|
|
|
- mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=h.dtype, device=h.device).triu(start_pos + 1).realize() if seqlen > 1 else None
|
|
|
+ mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
|
|
|
+
|
|
|
+ h = x
|
|
|
|
|
|
- for i in range(self.shard.start_layer, self.shard.end_layer + 1):
|
|
|
+ for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), self.caches[request_id]):
|
|
|
layer = self.layers[i]
|
|
|
- h = layer(h, start_pos, freqs_cis, mask)
|
|
|
+ h = layer(h, start_pos, freqs_cis, mask, cache=c)
|
|
|
|
|
|
if self.shard.is_last_layer():
|
|
|
logits = self.output(self.norm(h)).float().realize()
|
|
@@ -221,11 +226,26 @@ class Transformer:
|
|
|
else:
|
|
|
return h
|
|
|
|
|
|
- def __call__(self, tokens: Tensor, start_pos: Variable):
|
|
|
+ def init_cache(self, x: Tensor, request_id: str):
|
|
|
+ cache = [create_kv_cache(x, self.layers[i].attention.max_context, self.layers[i].attention.n_kv_heads, self.layers[i].attention.head_dim) for i in range(self.shard.start_layer, self.shard.end_layer + 1)]
|
|
|
+ if len(self.caches) >= self.max_caches:
|
|
|
+ self.caches.popitem(last=False)
|
|
|
+
|
|
|
+ self.caches[request_id] = cache
|
|
|
+
|
|
|
+ def __call__(self, tokens: Tensor, start_pos: Variable, request_id: str):
|
|
|
+ if self.shard.is_first_layer():
|
|
|
+ h = self.tok_embeddings(tokens)
|
|
|
+ else:
|
|
|
+ h = tokens
|
|
|
+ if request_id not in self.caches:
|
|
|
+ self.init_cache(h, request_id)
|
|
|
+ else:
|
|
|
+ self.caches.move_to_end(request_id)
|
|
|
# TODO: better way to handle the first call v.s. the rest?
|
|
|
if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
|
|
|
- return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos))
|
|
|
- return self.forward(tokens, start_pos)
|
|
|
+ return self.forward_jit(h, Variable("start_pos", 0, self.max_context).bind(start_pos), request_id)
|
|
|
+ return self.forward(h, start_pos, request_id)
|
|
|
|
|
|
|
|
|
# *** helpers ***
|