@@ -4,9 +4,8 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from mlx_lm.models.base import BaseModelArgs, create_additive_causal_mask
-from ...shard import Shard
-
+from exo.inference.shard import Shard
+from mlx_lm.models.base import BaseModelArgs, KVCache, create_additive_causal_mask
 
 @dataclass
 class NormalModelArgs(BaseModelArgs):
@@ -17,7 +16,9 @@ class NormalModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    head_dim: Optional[int] = None
+    max_position_embeddings: Optional[int] = None
+    num_key_value_heads: Optional[int] = None
     attention_bias: bool = False
     mlp_bias: bool = False
     rope_theta: float = 10000
@@ -30,12 +31,20 @@ class NormalModelArgs(BaseModelArgs):
             self.num_key_value_heads = self.num_attention_heads
 
         if self.rope_scaling:
-            required_keys = {"factor", "type"}
-            if not all(key in self.rope_scaling for key in required_keys):
-                raise ValueError(f"rope_scaling must contain keys {required_keys}")
+            if "factor" not in self.rope_scaling:
+                raise ValueError("rope_scaling must contain 'factor'")
+            rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
+                "rope_type"
+            )
+            if rope_type is None:
+                raise ValueError(
+                    "rope_scaling must contain either 'type' or 'rope_type'"
+                )
+            if rope_type not in ["linear", "dynamic", "llama3"]:
+                raise ValueError(
+                    "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
+                )
 
-            if self.rope_scaling["type"] != "linear":
-                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
 @dataclass
 class ModelArgs(NormalModelArgs):
     shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
@@ -50,6 +59,113 @@ class ModelArgs(NormalModelArgs):
 
         self.shard = Shard(**self.shard)
 
+class DynamicNTKScalingRoPE(nn.Module):
+    """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
+
+    def __init__(
+        self,
+        dims: int,
+        max_position_embeddings: int = 2048,
+        traditional: bool = False,
+        base: float = 10000,
+        scale: float = 1.0,
+        rope_type: str = "default",
+        rope_scaling: dict = None,
+    ):
+        super().__init__()
+        self.dims = dims
+        self.max_position_embeddings = max_position_embeddings
+        self.traditional = traditional
+        self.original_base = base
+        self.scale = scale
+        self.rope_type = rope_type
+        self.rope_scaling = rope_scaling
+        self.base = self.compute_base_freq()
+
+    def compute_base_freq(self):
+        if self.rope_type == "llama3":
+            return self.compute_llama3_base_freq()
+        return self.original_base
+
+    # source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
+    def compute_llama3_base_freq(self):
+        factor = self.rope_scaling["factor"]
+        low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
+        high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
+        old_context_len = self.rope_scaling.get(
+            "original_max_position_embeddings",
+            8192,
+        )
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+
+        freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims)
+        wavelens = 2 * mx.pi * freqs
+        new_base_freqs = []
+
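+        # Per-dimension rescaling in the Llama 3 style: long-wavelength components
+        # are scaled by `factor`, short-wavelength ones are kept as-is, and the band
+        # in between is blended via `smooths`; the result is averaged into one base.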
+        smooths = (wavelens - high_freq_wavelen) / (
+            low_freq_wavelen - high_freq_wavelen
+        )
+        new_base_freqs = freqs * (1 - smooths) * factor + smooths
+        new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
+        new_base_freqs = mx.where(
+            wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
+        )
+        return new_base_freqs.mean().item()
+
+    def extra_repr(self):
+        return (
+            f"{self.dims}, traditional={self.traditional}, "
+            f"max_position_embeddings={self.max_position_embeddings}, "
+            f"scaling_factor={self.scale}, rope_type={self.rope_type}"
+        )
+
+    def __call__(self, x, offset: int = 0):
+        seq_len = x.shape[1] + offset
+        base = self.base
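+        # Dynamic NTK: once the sequence outgrows max_position_embeddings, raise the
+        # base so the rotary wavelengths stretch with the longer context.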
+        if self.max_position_embeddings and seq_len > self.max_position_embeddings:
+            base *= (
+                (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
+            ) ** (self.dims / (self.dims - 2))
+
+        return mx.fast.rope(
+            x,
+            self.dims,
+            traditional=self.traditional,
+            base=base,
+            scale=self.scale,
+            offset=offset,
+        )
+
+
+def initialize_rope(args: ModelArgs):
+    head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
+
+    rope_scaling = args.rope_scaling
+    rope_type = "default"
+    rope_scale = 1.0
+
+    if rope_scaling is not None:
+        rope_type = (
+            rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
+        )
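+        # "linear" scaling shrinks position indices by 1/factor; "llama3" scaling is
+        # applied to the base inside DynamicNTKScalingRoPE, so the scale stays 1.0.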
+        if rope_type == "linear":
+            rope_scale = 1 / rope_scaling["factor"]
+        elif rope_type == "llama3":
+            rope_scale = 1.0  # The scaling is handled internally for llama3
+
+    return DynamicNTKScalingRoPE(
+        dims=head_dim,
+        max_position_embeddings=args.max_position_embeddings,
+        traditional=args.rope_traditional,
+        base=args.rope_theta,
+        scale=rope_scale,
+        rope_type=rope_type,
+        rope_scaling=rope_scaling,
+    )
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -58,7 +174,8 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
-        head_dim = args.hidden_size // n_heads
+        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
+
         self.scale = head_dim**-0.5
         if hasattr(args, "attention_bias"):
             attention_bias = args.attention_bias
@@ -70,23 +187,13 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
 
-        rope_scale = (
-            1 / args.rope_scaling["factor"]
-            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
-            head_dim,
-            traditional=args.rope_traditional,
-            base=args.rope_theta,
-            scale=rope_scale,
-        )
+        self.rope = initialize_rope(args)
 
     def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
-       cache: Optional[Tuple[mx.array, mx.array]] = None,
+       cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -148,7 +255,7 @@ class TransformerBlock(nn.Module):
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
-       cache: Optional[Tuple[mx.array, mx.array]] = None,
+       cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
@@ -223,7 +330,6 @@ class Model(nn.Module):
 
         return out
 
-
     def sanitize(self, weights):
         # Remove unused precomputed rotary freqs
         return {
@@ -236,7 +342,9 @@ class Model(nn.Module):
 
     @property
     def head_dim(self):
-        return self.args.hidden_size // self.args.num_attention_heads
+        return (
+            self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
+        )
 
     @property
     def n_kv_heads(self):