@@ -1,6 +1,7 @@
 from typing import Tuple, Union, Optional, Dict, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
+from exo.inference.shard import Shard

 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
@@ -144,42 +145,47 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   return output_token

 class Transformer:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
-    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, shard: Shard, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
+    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(shard.end_layer - shard.start_layer + 1)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
     self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
     self.forward_jit = TinyJit(self.forward) if jit else None
+    self.shard = shard

-  def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
-    _bsz, seqlen = tokens.shape
+  def forward(self, h:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
+    seqlen = h.shape[1]
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))

-    h = self.tok_embeddings(tokens)
+    if self.shard.is_first_layer():
+      h = self.tok_embeddings(h)
     mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None
+
     for i, layer in enumerate(self.layers):
       h = layer(h, start_pos, freqs_cis, mask)
-      print(f"layer {i}", h.tolist().__str__()[0:100])
-    logits = self.output(self.norm(h)).float()[:, -1, :]

-    return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+    if self.shard.is_last_layer():
+      logits = self.output(self.norm(h)).float()[:, -1, :]
+      return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+    else:
+      return h.realize()

   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
-      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
+    # if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
+    #   return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
     return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)

   def reset(self):
     for layer in self.layers:
-      print(f"reset layer: {layer.attention.cache_kv}")
-      layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
+      if hasattr(layer.attention, "cache_kv"):
+        layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()

 # *** helpers ***

-def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
+def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard):
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
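The `forward` changes above turn each sharded model into one pipeline stage: only the shard holding the first layer embeds raw token ids, only the shard holding the last layer projects to logits and samples, and every shard in between just returns its hidden states. A minimal sketch of how such stages could be chained follows; this is not exo's actual orchestration code, and it assumes only what the diff itself uses (shards ordered by `start_layer`, and the `is_first_layer()`/`is_last_layer()` behaviour of `Shard`).

```python
# Sketch only: chaining sharded Transformers the way the modified forward() expects.
# "models" is assumed to be a list of Transformer instances ordered by shard.start_layer;
# this is a hypothetical driver, not exo's scheduler.
from tinygrad import Tensor

def run_pipeline(models, tokens: Tensor, start_pos: int = 0) -> Tensor:
  h = tokens  # the first shard receives raw token ids of shape (1, seqlen)
  for model in models:
    # non-final shards return hidden states; the final shard returns a sampled token
    h = model(h, start_pos)
  return h
```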
@@ -197,6 +203,12 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
     if ".rotary_emb." in k: continue
     v = v.to(Device.DEFAULT)
     if "model.layers" in k:
+      layer_num = int(k.split('.')[2])
+      if shard.start_layer <= layer_num <= shard.end_layer:
+        k = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
+      else:
+        continue
+
       if "q_proj" in k:
         v = permute(v, n_heads)
       elif "k_proj" in k:
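The lines added in the last hunk translate Hugging Face's global layer indices into shard-local ones and drop weights the shard does not own, so each node only loads its own slice of the checkpoint. A self-contained illustration of that index arithmetic is below; `remap_layer_key` is a hypothetical helper written for clarity, not part of the diff.

```python
# Illustrative only: the same remapping logic as the added lines above,
# pulled out into a standalone helper with a couple of worked examples.
from typing import Optional

def remap_layer_key(k: str, start_layer: int, end_layer: int) -> Optional[str]:
  if "model.layers" not in k:
    return k                                  # non-layer weights are left untouched
  layer_num = int(k.split('.')[2])            # global layer index from the HF checkpoint
  if not (start_layer <= layer_num <= end_layer):
    return None                               # weight belongs to another shard: skip it
  # e.g. global layer 17 on a shard covering layers 16..31 becomes local layer 1
  return f"model.layers.{layer_num - start_layer}." + '.'.join(k.split('.')[3:])

assert remap_layer_key("model.layers.17.self_attn.q_proj.weight", 16, 31) == "model.layers.1.self_attn.q_proj.weight"
assert remap_layer_key("model.layers.5.mlp.gate_proj.weight", 16, 31) is None
```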