Browse Source

added rope_scaling and tie_word_embeddings to llama transformer

Ogden Wells 7 months ago
parent
commit
af01b23a07
1 changed files with 18 additions and 3 deletions
  1. 18 3
      exo/inference/tinygrad/models/llama.py

+ 18 - 3
exo/inference/tinygrad/models/llama.py

@@ -4,8 +4,19 @@ from tinygrad.helpers import getenv
 
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half, rope_scaling: Optional[Dict[str, float]] = None) -> Tensor:
   freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
+
+  if rope_scaling:
+    factor = rope_scaling.get('factor', 1.0)
+    low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
+    high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
+    original_max_pos_emb = rope_scaling.get('original_max_position_embeddings', end)
+
+    freqs[:dim // 4] *= low_freq_factor
+    freqs[dim // 4:] = freqs[dim // 4:].contiguous() * high_freq_factor
+    freqs *= (original_max_pos_emb / end) ** (1.0 / factor)
+
   freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
   return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
@@ -176,14 +187,18 @@ class Transformer:
     rope_theta=10000,
     max_context=1024,
     jit=True,
-    feed_forward=FeedForward
+    feed_forward=FeedForward,
+    rope_scaling: Optional[Dict[str, float]] = None,
+    tie_word_embeddings=False
   ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
+    if tie_word_embeddings:
+      self.output.weight = self.tok_embeddings.weight
     self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta).contiguous()
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard