add support for llama3.1 (8b, 70b, 405b). bump mlx up to 0.16.0 and mlx-lm up to 0.16.1. fixes #66

Alex Cheema · 11 months ago
commit bbfd5adc20
4 changed files with 152 additions and 35 deletions
  1. exo/api/chatgpt_api.py (+16, -7)
  2. exo/inference/mlx/models/sharded_llama.py (+133, -25)
  3. exo/inference/test_inference_engine.py (+1, -1)
  4. setup.py (+2, -2)

exo/api/chatgpt_api.py (+16, -7)

@@ -17,6 +17,15 @@ shard_mappings = {
         "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
         "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
     },
+    "llama-3.1-8b": {
+        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    },
+    "llama-3.1-70b": {
+        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    },
+    "llama-3.1-405b": {
+        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
+    },
     "llama-3-70b": {
         "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
         "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
@@ -42,7 +51,7 @@ def resolve_tinygrad_tokenizer(model_id: str):
     else:
         raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")
 
-def resolve_tokenizer(model_id: str):
+async def resolve_tokenizer(model_id: str):
     try:
         if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
         return AutoTokenizer.from_pretrained(model_id)
@@ -61,7 +70,7 @@ def resolve_tokenizer(model_id: str):
 
     if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
     from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
-    return load_tokenizer(get_model_path(model_id))
+    return load_tokenizer(await get_model_path(model_id))
 
 def generate_completion(
         chat_request: ChatCompletionRequest,
@@ -146,24 +155,24 @@ class ChatGPTAPI:
 
     async def handle_post_chat_token_encode(self, request):
         data = await request.json()
-        shard = shard_mappings.get(data.get('model', 'llama-3-8b'), {}).get(self.inference_engine_classname)
+        shard = shard_mappings.get(data.get('model', 'llama-3.1-8b'), {}).get(self.inference_engine_classname)
         messages = data.get('messages', [])
-        tokenizer = resolve_tokenizer(shard.model_id)
+        tokenizer = await resolve_tokenizer(shard.model_id)
         return web.json_response({'length': len(build_prompt(tokenizer, messages))})
 
     async def handle_post_chat_completions(self, request):
         data = await request.json()
         stream = data.get('stream', False)
         messages = [Message(**msg) for msg in data['messages']]
-        chat_request = ChatCompletionRequest(data.get('model', 'llama-3-8b'), messages, data.get('temperature', 0.0))
+        chat_request = ChatCompletionRequest(data.get('model', 'llama-3.1-8b'), messages, data.get('temperature', 0.0))
         if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
-            chat_request.model = "llama-3-8b"
+            chat_request.model = "llama-3.1-8b"
         shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
         if not shard:
             return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
         request_id = str(uuid.uuid4())
 
-        tokenizer = resolve_tokenizer(shard.model_id)
+        tokenizer = await resolve_tokenizer(shard.model_id)
         if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
         prompt = build_prompt(tokenizer, messages)
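
Since resolve_tokenizer is now a coroutine (it awaits get_model_path for MLX models), every call site has to await it, as the handlers above now do. A minimal sketch of the new calling convention, modeled on handle_post_chat_token_encode; the standalone handler name and the hard-coded engine class name are illustrative assumptions, not project code:

    # Hypothetical standalone handler mirroring the pattern in the diff above.
    async def example_token_count(request):
        data = await request.json()
        model = data.get('model', 'llama-3.1-8b')  # new default model
        shard = shard_mappings.get(model, {}).get("MLXDynamicShardInferenceEngine")
        if shard is None:
            return web.json_response({'detail': f"Invalid model: {model}"}, status=400)
        tokenizer = await resolve_tokenizer(shard.model_id)  # must be awaited after this change
        messages = data.get('messages', [])
        return web.json_response({'length': len(build_prompt(tokenizer, messages))})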

exo/inference/mlx/models/sharded_llama.py (+133, -25)

@@ -4,9 +4,8 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from mlx_lm.models.base import BaseModelArgs, create_additive_causal_mask
-from ...shard import Shard
-
+from exo.inference.shard import Shard
+from mlx_lm.models.base import BaseModelArgs, KVCache, create_additive_causal_mask
 
 @dataclass
 class NormalModelArgs(BaseModelArgs):
@@ -17,7 +16,9 @@ class NormalModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    head_dim: Optional[int] = None
+    max_position_embeddings: Optional[int] = None
+    num_key_value_heads: Optional[int] = None
     attention_bias: bool = False
     mlp_bias: bool = False
     rope_theta: float = 10000
@@ -30,12 +31,20 @@ class NormalModelArgs(BaseModelArgs):
             self.num_key_value_heads = self.num_attention_heads
 
         if self.rope_scaling:
-            required_keys = {"factor", "type"}
-            if not all(key in self.rope_scaling for key in required_keys):
-                raise ValueError(f"rope_scaling must contain keys {required_keys}")
+            if not "factor" in self.rope_scaling:
+                raise ValueError(f"rope_scaling must contain 'factor'")
+            rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
+                "rope_type"
+            )
+            if rope_type is None:
+                raise ValueError(
+                    f"rope_scaling must contain either 'type' or 'rope_type'"
+                )
+            if rope_type not in ["linear", "dynamic", "llama3"]:
+                raise ValueError(
+                    "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
+                )
 
-            if self.rope_scaling["type"] != "linear":
-                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
 @dataclass
 class ModelArgs(NormalModelArgs):
     shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
@@ -50,6 +59,113 @@ class ModelArgs(NormalModelArgs):
 
         self.shard = Shard(**self.shard)
 
+class DynamicNTKScalingRoPE(nn.Module):
+    """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
+
+    def __init__(
+        self,
+        dims: int,
+        max_position_embeddings: int = 2048,
+        traditional: bool = False,
+        base: float = 10000,
+        scale: float = 1.0,
+        rope_type: str = "default",
+        rope_scaling: dict = None,
+    ):
+        super().__init__()
+        self.dims = dims
+        self.max_position_embeddings = max_position_embeddings
+        self.traditional = traditional
+        self.original_base = base
+        self.scale = scale
+        self.rope_type = rope_type
+        self.rope_scaling = rope_scaling
+        self.base = self.compute_base_freq()
+
+    def compute_base_freq(self):
+        if self.rope_type == "llama3":
+            return self.compute_llama3_base_freq()
+        return self.original_base
+
+    # source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
+    def compute_llama3_base_freq(self):
+        factor = self.rope_scaling["factor"]
+        low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
+        high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
+        old_context_len = self.rope_scaling.get(
+            "original_max_position_embeddings",
+            8192,
+        )
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+
+        freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims)
+        wavelens = 2 * mx.pi * freqs
+        new_base_freqs = []
+
+        smooths = (wavelens - high_freq_wavelen) / (
+            low_freq_wavelen - high_freq_wavelen
+        )
+        new_base_freqs = freqs * (1 - smooths) * factor + smooths
+        new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
+        new_base_freqs = mx.where(
+            wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
+        )
+        return new_base_freqs.mean().item()
+
+    def extra_repr(self):
+        return (
+            f"{self.dims}, traditional={self.traditional}, "
+            f"max_position_embeddings={self.max_position_embeddings}, "
+            f"scaling_factor={self.scale}, rope_type={self.rope_type}"
+        )
+
+    def __call__(self, x, offset: int = 0):
+        seq_len = x.shape[1] + offset
+        base = self.base
+        if self.max_position_embeddings and seq_len > self.max_position_embeddings:
+            base *= (
+                (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
+            ) ** (self.dims / (self.dims - 2))
+
+        return mx.fast.rope(
+            x,
+            self.dims,
+            traditional=self.traditional,
+            base=base,
+            scale=self.scale,
+            offset=offset,
+        )
+
+
+def initialize_rope(args: ModelArgs):
+    head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
+
+    rope_scaling = args.rope_scaling
+    rope_type = "default"
+    rope_scale = 1.0
+
+    if rope_scaling is not None:
+        rope_type = (
+            rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
+        )
+        if rope_type == "linear":
+            rope_scale = 1 / rope_scaling["factor"]
+        elif rope_type == "llama3":
+            rope_scale = 1.0  # The scaling is handled internally for llama3
+
+    return DynamicNTKScalingRoPE(
+        dims=head_dim,
+        max_position_embeddings=args.max_position_embeddings,
+        traditional=args.rope_traditional,
+        base=args.rope_theta,
+        scale=rope_scale,
+        rope_type=rope_type,
+        rope_scaling=rope_scaling,
+    )
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -58,7 +174,8 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
-        head_dim = args.hidden_size // n_heads
+        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
+
         self.scale = head_dim**-0.5
         if hasattr(args, "attention_bias"):
             attention_bias = args.attention_bias
@@ -70,23 +187,13 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
 
-        rope_scale = (
-            1 / args.rope_scaling["factor"]
-            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
-            head_dim,
-            traditional=args.rope_traditional,
-            base=args.rope_theta,
-            scale=rope_scale,
-        )
+        self.rope = initialize_rope(args)
 
     def __call__(
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -148,7 +255,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
@@ -223,7 +330,6 @@ class Model(nn.Module):
 
         return out
 
-
     def sanitize(self, weights):
         # Remove unused precomputed rotary freqs
         return {
@@ -236,7 +342,9 @@ class Model(nn.Module):
 
     @property
     def head_dim(self):
-        return self.args.hidden_size // self.args.num_attention_heads
+        return (
+            self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
+        )
 
     @property
     def n_kv_heads(self):
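
The new DynamicNTKScalingRoPE class is what enables the 3.1 checkpoints: for rope_type "llama3" it collapses the per-frequency smoothing into a single effective base (compute_llama3_base_freq) before delegating to mx.fast.rope. A minimal sketch of exercising it directly, using the rope_scaling values Llama 3.1 configs ship with; the dims value and tensor shape are illustrative assumptions, not project code:

    # Hypothetical standalone usage; constructor arguments follow the class above.
    import mlx.core as mx
    from exo.inference.mlx.models.sharded_llama import DynamicNTKScalingRoPE

    rope = DynamicNTKScalingRoPE(
        dims=128,                        # head_dim of the 8B model (4096 / 32 heads)
        max_position_embeddings=131072,
        traditional=False,
        base=500000.0,                   # rope_theta used by Llama 3.1
        scale=1.0,                       # llama3 smoothing is handled in compute_llama3_base_freq
        rope_type="llama3",
        rope_scaling={
            "factor": 8.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_max_position_embeddings": 8192,
        },
    )
    x = mx.random.normal((1, 32, 16, 128))  # (batch, n_heads, seq_len, head_dim), as in Attention
    y = rope(x, offset=0)                   # rotates the trailing dims=128 features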

exo/inference/test_inference_engine.py (+1, -1)

@@ -27,7 +27,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
 asyncio.run(test_inference_engine(
     MLXDynamicShardInferenceEngine(),
     MLXDynamicShardInferenceEngine(),
-    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+    "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
 ))
 
 # TODO: Need more memory or a smaller model

setup.py (+2, -2)

@@ -30,8 +30,8 @@ install_requires = [
 if sys.platform.startswith("darwin"):
     install_requires.extend(
         [
-            "mlx==0.15.1",
-            "mlx-lm==0.14.3",
+            "mlx==0.16.0",
+            "mlx-lm==0.16.1",
         ]
     )
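
A quick, illustrative way to confirm an environment actually picked up the new pins (not project code):

    # Illustrative sanity check against the pinned versions in setup.py.
    import importlib.metadata as md
    print(md.version("mlx"), md.version("mlx-lm"))  # expect 0.16.0 and 0.16.1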