@@ -1,4 +1,4 @@
-from typing import Tuple, Union, Optional, Dict, Any
+from typing import Tuple, Union, Optional, Dict, Any, List
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 from collections import OrderedDict
@@ -48,13 +48,6 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
 
-def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
-  cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
-  if isinstance(x.device, tuple):
-    # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-    cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
-  return cache_kv.realize()
-
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
@@ -194,7 +187,6 @@ class Transformer:
     feed_forward=FeedForward,
     rope_scaling: Optional[Dict[str, float]] = None,
     tie_word_embeddings=False,
-    max_caches=2,
   ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
@@ -206,17 +198,17 @@
     self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
-    self.caches = OrderedDict()
-    self.max_caches = max_caches
 
-  def forward(self, x: Tensor, start_pos: Union[Variable, int], request_id: str):
+  def forward(self, x: Tensor, start_pos: Union[Variable, int], cache: Optional[List[Tensor]] = None):
     seqlen = x.shape[1]
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
 
     h = x
 
-    for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), self.caches[request_id]):
+    if cache is None:
+      cache = [None for _ in range(self.shard.start_layer, self.shard.end_layer + 1)]
+    for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), cache):
       layer = self.layers[i]
       h = layer(h, start_pos, freqs_cis, mask, cache=c)
 
@@ -226,26 +218,19 @@ class Transformer:
     else:
       return h
 
-  def init_cache(self, x: Tensor, request_id: str):
-    cache = [create_kv_cache(x, self.layers[i].attention.max_context, self.layers[i].attention.n_kv_heads, self.layers[i].attention.head_dim) for i in range(self.shard.start_layer, self.shard.end_layer + 1)]
-    if len(self.caches) >= self.max_caches:
-      self.caches.popitem(last=False)
-
-    self.caches[request_id] = cache
-
-  def __call__(self, tokens: Tensor, start_pos: Variable, request_id: str):
+  def embed(self, inputs: Tensor):
     if self.shard.is_first_layer():
-      h = self.tok_embeddings(tokens)
+      h = self.tok_embeddings(inputs)
     else:
-      h = tokens
-    if request_id not in self.caches:
-      self.init_cache(h, request_id)
-    else:
-      self.caches.move_to_end(request_id)
+      h = inputs
+    return h
+
+  def __call__(self, tokens: Tensor, start_pos: Variable, request_id: str, cache: Optional[List[Tensor]] = None):
     # TODO: better way to handle the first call v.s. the rest?
+    h = self.embed(tokens)
     if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
-      return self.forward_jit(h, Variable("start_pos", 0, self.max_context).bind(start_pos), request_id)
-    return self.forward(h, start_pos, request_id)
+      return self.forward_jit(h, Variable("start_pos", 0, self.max_context).bind(start_pos), cache=cache)
+    return self.forward(h, start_pos, cache=cache)
 
 
 # *** helpers ***
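
Usage note (reviewer sketch, not part of the patch): with create_kv_cache and the per-request
cache dict removed from Transformer, the caller now owns the KV cache and threads it through
the new cache argument on __call__/forward. Below is a minimal caller-side sketch that mirrors
the tensor shape used by the removed create_kv_cache; the names model, tokens, start_pos and
make_kv_caches are illustrative, and the multi-device shard_ step of the old helper is omitted.

  from tinygrad import Tensor

  def make_kv_caches(model, x: Tensor):
    # One zeroed cache per layer in this shard: (2 = keys/values, batch, max_context, n_kv_heads, head_dim)
    caches = []
    for i in range(model.shard.start_layer, model.shard.end_layer + 1):
      attn = model.layers[i].attention
      caches.append(Tensor.zeros(2, x.shape[0], attn.max_context, attn.n_kv_heads, attn.head_dim,
                                 dtype=x.dtype).contiguous().realize())
    return caches

  h = model.embed(tokens)                # activations whose dtype and batch size the cache must match
  cache = make_kv_caches(model, h)       # held by the caller and reused across decode steps
  logits = model(tokens, start_pos, "request-0", cache=cache)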