@@ -4,8 +4,19 @@ from tinygrad.helpers import getenv

 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half, rope_scaling: Optional[Dict[str, float]] = None) -> Tensor:
   freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
+
+  if rope_scaling:
+    factor = rope_scaling.get('factor', 1.0)
+    low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
+    high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
+    original_max_pos_emb = rope_scaling.get('original_max_position_embeddings', end)
+
+    freqs[:dim // 4] *= low_freq_factor
+    freqs[dim // 4:] = freqs[dim // 4:].contiguous()*high_freq_factor
+    freqs *= (original_max_pos_emb/end)**(1.0/factor)
+
   freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
   return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)

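The new `rope_scaling` branch splits the `dim // 2` inverse frequencies into two bands: the first half is multiplied by `low_freq_factor`, the second half by `high_freq_factor`, and every frequency is then rescaled by `(original_max_pos_emb/end)**(1.0/factor)`. A minimal usage sketch follows; the dict values mirror the shape of a Llama 3.1 style `config.json` `rope_scaling` block and the import path assumes the patched tinygrad module — both are illustrative assumptions, not part of this patch:

```python
# Sketch only: the import path assumes this patch lands in tinygrad's
# extra/models/llama.py; adjust to wherever the function actually lives.
from extra.models.llama import precompute_freqs_cis

# Values shaped like Llama 3.1's rope_scaling config block (assumed here).
rope_scaling = {
  "factor": 8.0,
  "low_freq_factor": 1.0,
  "high_freq_factor": 4.0,
  "original_max_position_embeddings": 8192,
}

# head_dim = dim // n_heads and end = 2 * max_context, matching the call
# site in the Transformer constructor below.
freqs_cis = precompute_freqs_cis(128, 16384, theta=500000.0, rope_scaling=rope_scaling)
print(freqs_cis.shape)  # (1, 16384, 1, 64, 2)
```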
@@ -176,14 +187,18 @@ class Transformer:
     rope_theta=10000,
     max_context=1024,
     jit=True,
-    feed_forward=FeedForward
+    feed_forward=FeedForward,
+    rope_scaling: Optional[Dict[str, float]] = None,
+    tie_word_embeddings=False
   ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
+    if tie_word_embeddings:
+      self.output.weight = self.tok_embeddings.weight
     self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta).contiguous()
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard

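With `tie_word_embeddings`, the output projection reuses the token embedding matrix (both are `(vocab_size, dim)`), so checkpoints that ship a single tied matrix load without a separate `output.weight`. Below is a construction sketch under assumed hyperparameters, roughly Llama 3.1 8B shaped; the leading keyword list is inferred from the attributes referenced in this hunk and may not match the full upstream signature:

```python
# Sketch only: import path, hyperparameter values, and the exact keyword
# list are assumptions inferred from the hunk above, not confirmed by it.
from extra.models.llama import Transformer

model = Transformer(
  dim=4096, hidden_dim=14336, n_heads=32, n_layers=32,
  norm_eps=1e-5, vocab_size=128256, n_kv_heads=8,
  rope_theta=500000,
  max_context=8192,
  rope_scaling=rope_scaling,      # same dict as in the sketch above
  tie_word_embeddings=False,      # True for tied-embedding checkpoints (e.g. Llama 3.2 1B)
)
```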