Răsfoiți Sursa

shareded inference

Varshith 9 luni în urmă
părinte
comite
9d2616b9cf

+ 1 - 0
.gitignore

@@ -83,6 +83,7 @@ target/
 
 # Jupyter Notebook
 .ipynb_checkpoints
+Untitled.ipynb
 
 # IPython
 profile_default/

+ 117 - 102
exo/inference/mlx/models/sharded_llava.py

@@ -1,18 +1,15 @@
 # Copyright © 2024 Apple Inc.
 
 import math
-import glob
 import inspect
-import json
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Optional, Dict, Union, Tuple
+from dataclasses import dataclass, field
+from typing import Optional, Dict, Union
 
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.base import BaseModelArgs, KVCache
+from exo.inference.shard import Shard
 import numpy as np
-from huggingface_hub import snapshot_download
 
 
 @dataclass
@@ -42,15 +39,15 @@ class VisionConfig:
 
 class VisionAttention(nn.Module):
     def __init__(
-        self,
-        dims: int,
-        num_heads: int,
-        query_input_dims: Optional[int] = None,
-        key_input_dims: Optional[int] = None,
-        value_input_dims: Optional[int] = None,
-        value_dims: Optional[int] = None,
-        value_output_dims: Optional[int] = None,
-        bias: bool = False,
+            self,
+            dims: int,
+            num_heads: int,
+            query_input_dims: Optional[int] = None,
+            key_input_dims: Optional[int] = None,
+            value_input_dims: Optional[int] = None,
+            value_dims: Optional[int] = None,
+            value_output_dims: Optional[int] = None,
+            bias: bool = False,
     ):
         super().__init__()
 
@@ -206,7 +203,7 @@ class VisionModel(nn.Module):
         self.vision_model = ClipVisionModel(config)
 
     def __call__(
-        self, x: mx.array, output_hidden_states: Optional[bool] = None
+            self, x: mx.array, output_hidden_states: Optional[bool] = None
     ) -> mx.array:
         return self.vision_model(x, output_hidden_states)
 
@@ -228,6 +225,7 @@ class VisionModel(nn.Module):
 
         return sanitized_weights
 
+
 @dataclass
 class TextConfig:
     model_type: str
@@ -235,10 +233,10 @@ class TextConfig:
     num_hidden_layers: int = 32
     intermediate_size: int = 11008
     num_attention_heads: int = 32
+    head_dim: int = None
     rms_norm_eps: float = 1e-6
     vocab_size: int = 32000
-    n_kv_heads: int = None
-    head_dim: Optional[int] = None
+    num_key_value_heads: int = None
     rope_theta: float = 10000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -254,12 +252,15 @@ class TextConfig:
         )
 
     def __post_init__(self):
-        if self.n_kv_heads is None:
-            self.n_kv_heads = self.num_attention_heads
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
 
         if self.head_dim is None:
             self.head_dim = self.hidden_size // self.num_attention_heads
 
+        if self.model_type is None:
+            self.model_type = "llama"
+
         if self.rope_scaling:
             required_keys = {"factor", "type"}
             if not all(key in self.rope_scaling for key in required_keys):
@@ -275,12 +276,12 @@ class TextAttention(nn.Module):
 
         dim = config.hidden_size
         self.n_heads = n_heads = config.num_attention_heads
-        self.n_kv_heads = n_kv_heads = config.n_kv_heads
+        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
 
         self.repeats = n_heads // n_kv_heads
 
         head_dim = config.hidden_size // n_heads
-        self.scale = head_dim**-0.5
+        self.scale = head_dim ** -0.5
 
         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
@@ -290,7 +291,7 @@ class TextAttention(nn.Module):
         rope_scale = (
             1 / config.rope_scaling["factor"]
             if config.rope_scaling is not None
-            and config.rope_scaling["type"] == "linear"
+               and config.rope_scaling["type"] == "linear"
             else 1
         )
         self.rope = nn.RoPE(
@@ -301,10 +302,10 @@ class TextAttention(nn.Module):
         )
 
     def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[KVCache] = None,
+            self,
+            x: mx.array,
+            mask: Optional[mx.array] = None,
+            cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -355,10 +356,10 @@ class TransformerBlock(nn.Module):
         self.config = config
 
     def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[KVCache] = None,
+            self,
+            x: mx.array,
+            mask: Optional[mx.array] = None,
+            cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
@@ -368,12 +369,15 @@ class TransformerBlock(nn.Module):
 
 
 class Llama(nn.Module):
-    def __init__(self, config: TextConfig):
+    def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
         super().__init__()
         self.config = config
+        self.is_first_layer = is_first_layer
+        self.is_last_layer = is_last_layer
         self.vocab_size = config.vocab_size
+        self.model_type = config.model_type
         self.num_hidden_layers = config.num_hidden_layers
-        self.n_kv_heads = config.n_kv_heads
+        self.num_key_value_heads = config.num_key_value_heads
         self.head_dim = config.head_dim
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
@@ -383,14 +387,17 @@ class Llama(nn.Module):
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def __call__(
-        self,
-        inputs: mx.array,
-        cache=None,
-        inputs_embeds=None,
+            self,
+            inputs: mx.array,
+            cache=None,
+            inputs_embeds=None,
     ):
         # for passing merged input embeddings
         if inputs_embeds is None:
-            h = self.embed_tokens(inputs)
+            if self.is_first_layer:
+                h = self.embed_tokens(inputs)
+            else:
+                h = inputs
         else:
             h = inputs_embeds
 
@@ -406,18 +413,20 @@ class Llama(nn.Module):
         for layer, c in zip(self.layers, cache):
             h = layer(h, mask, c)
 
-        return self.norm(h)
-
+        if self.is_last_layer:
+            h = self.norm(h)
+        return h
 
 class LanguageModel(nn.Module):
-    def __init__(self, config: TextConfig):
+    def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
         super().__init__()
         self.model_type = config.model_type
         if self.model_type != "llama":
             raise ValueError(
                 f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
             )
-        self.model = Llama(config)
+        self.is_last_layer = is_last_layer
+        self.model = Llama(config, is_first_layer, is_last_layer)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
     def __call__(
@@ -427,7 +436,9 @@ class LanguageModel(nn.Module):
         inputs_embeds=None,
     ):
         out = self.model(inputs, cache, inputs_embeds)
-        return self.lm_head(out)
+        if self.is_last_layer:
+            out = self.lm_head(out)
+        return out
 
     @staticmethod
     def sanitize(weights):
@@ -436,11 +447,10 @@ class LanguageModel(nn.Module):
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
 
-
 @dataclass
-class LlaVAConfig:
+class LlaVAConfig(BaseModelArgs):
     text_config: TextConfig
-    vision_config: VisionConfig
+    vision_config: VisionConfig = None
     model_type: str = "llava"
     ignore_index: int = -100
     image_token_index: int = 32000
@@ -450,13 +460,32 @@ class LlaVAConfig:
 
     @classmethod
     def from_dict(cls, params):
-        return cls(
-            **{
-                k: v
-                for k, v in params.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
+        updated_params = {}
+        class_params = inspect.signature(cls).parameters
+        for k, v in params.items():
+            if k in class_params:
+                if k in ["text_config", "vision_config"]:
+                    v = class_params[k].annotation.from_dict(v)
+                updated_params.update({k: v})
+
+        return cls(**updated_params)
+
+
+@dataclass
+class ModelArgs(LlaVAConfig):
+    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+    def __post_init__(self):
+        if isinstance(self.shard, dict):
+            self.shard = Shard(**self.shard)
+
+        if not isinstance(self.shard, Shard):
+            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+        if not self.shard.is_first_layer():
+            self.vision_config = None
+
+        self.text_config.num_hidden_layers = self.shard.get_layer_count()
 
 
 class LlavaMultiModalProjector(nn.Module):
@@ -477,19 +506,22 @@ class LlavaMultiModalProjector(nn.Module):
         return x
 
 
-class LlavaModel(nn.Module):
-    def __init__(self, config: LlaVAConfig):
+class Model(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
         self.config = config
-        self.vision_tower = VisionModel(config.vision_config)
-        self.language_model = LanguageModel(config.text_config)
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.vision_feature_layer = config.vision_feature_layer
-        self.vision_feature_select_strategy = config.vision_feature_select_strategy
+        self.model_type = config.model_type
+        if config.vision_config:
+            self.vision_tower = VisionModel(config.vision_config)
+            self.multi_modal_projector = LlavaMultiModalProjector(config)
+            self.vision_feature_layer = config.vision_feature_layer
+            self.vision_feature_select_strategy = config.vision_feature_select_strategy
+        self.language_model = LanguageModel(config.text_config, config.shard.is_first_layer(), config.shard.is_last_layer())
 
     def get_input_embeddings(
-        self,
-        input_ids: Optional[mx.array] = None,
-        pixel_values: Optional[mx.array] = None,
+            self,
+            input_ids: Optional[mx.array] = None,
+            pixel_values: Optional[mx.array] = None,
     ):
         if pixel_values is None:
             return self.language_model(input_ids)
@@ -525,7 +557,7 @@ class LlavaModel(nn.Module):
         return final_inputs_embeds
 
     def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds, input_ids
+            self, image_features, inputs_embeds, input_ids
     ):
         image_token_index = self.config.image_token_index
         num_images, num_image_patches, embed_dim = image_features.shape
@@ -554,49 +586,32 @@ class LlavaModel(nn.Module):
         # (1, num_image_patches*num_images + sequence_len, embed_dim)
         return mx.concatenate(final_embeddings, axis=1)
 
-    def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
-        input_embddings = self.get_input_embeddings(input_ids, pixel_values)
+    def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
+        input_embddings = None
+        if pixel_values is not None:
+            input_embddings = self.get_input_embeddings(input_ids, pixel_values)
         logits = self.language_model(
             input_ids, cache=cache, inputs_embeds=input_embddings
         )
         return logits
 
-    @staticmethod
-    def from_pretrained(path_or_hf_repo: str):
-        path = Path(path_or_hf_repo)
-        if not path.exists():
-            path = Path(
-                snapshot_download(
-                    repo_id=path_or_hf_repo,
-                    allow_patterns=[
-                        "*.json",
-                        "*.safetensors",
-                        "*.py",
-                        "tokenizer.model",
-                        "*.tiktoken",
-                    ],
-                )
-            )
-
-        with open(path / "config.json", "r") as f:
-            model_config = json.load(f)
-
-        model_config = LlaVAConfig.from_dict(model_config)
+    def sanitize(self, weights):
+        if self.config.vision_config:
+            weights = self.vision_tower.sanitize(weights)
+        weights = self.language_model.sanitize(weights)
 
-        model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
-        model_config.text_config = TextConfig.from_dict(model_config.text_config)
+        return weights
 
-        model = LlavaModel(model_config)
-        weight_files = glob.glob(str(path / "*.safetensors"))
-        if not weight_files:
-            raise FileNotFoundError(f"No safetensors found in {path}")
+    @property
+    def layers(self):
+        return self.language_model.model.layers
 
-        weights = {}
-        for wf in weight_files:
-            weights.update(mx.load(wf))
-
-        weights = VisionModel.sanitize(weights)
-        weights = LanguageModel.sanitize(weights)
+    @property
+    def head_dim(self):
+        return (
+                self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads
+        )
 
-        model.load_weights(list(weights.items()))
-        return model
+    @property
+    def n_kv_heads(self):
+        return self.language_model.model.num_key_value_heads

+ 11 - 13
exo/inference/mlx/sharded_model.py

@@ -15,7 +15,8 @@ class StatefulShardedModel:
 
     def step(
         self,
-        x,
+        y,
+        pixel_values=None,
         temp: float = 0.0,
         top_p: float = 1.0,
         logit_bias: Optional[Dict[int, float]] = None,
@@ -36,9 +37,11 @@ class StatefulShardedModel:
 
             return token
 
-        y = x
-
-        output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)
+        # TODO : revert hacky fix
+        if pixel_values is None:
+            output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)
+        else:
+            output = self.model(y, pixel_values=pixel_values, cache=self.cache)
 
         if self.shard.is_last_layer():
             logits = output[:, -1, :]
@@ -57,14 +60,9 @@ class StatefulShardedModel:
         return self.step(x, temp, top_p, logit_bias)
 
     def reset(self):
-        if hasattr(self.model.config, "vision_config"):
-            model = self.model.language_model.model
-        else:
-            model = self.model
-
         kv_heads = (
-            [model.n_kv_heads] * len(model.layers)
-            if isinstance(model.n_kv_heads, int)
-            else model.n_kv_heads
+            [self.model.n_kv_heads] * len(self.model.layers)
+            if isinstance(self.model.n_kv_heads, int)
+            else self.model.n_kv_heads
         )
-        self.cache = [KVCache(model.head_dim, n) for n in kv_heads]
+        self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads]

+ 14 - 59
exo/inference/mlx/sharded_utils.py

@@ -19,7 +19,6 @@ from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tuner.utils import apply_lora_layers
 
 from ..shard import Shard
-from exo.inference.mlx.models.sharded_llava import LlavaModel, LlaVAConfig, VisionConfig, VisionModel, TextConfig, LanguageModel
 
 class ModelNotFoundError(Exception):
     def __init__(self, message):
@@ -29,6 +28,7 @@ class ModelNotFoundError(Exception):
 MODEL_REMAPPING = {
     "sharded_mistral": "sharded_llama",  # mistral is compatible with llama
     "sharded_phi-msft": "sharded_phixtral",
+    "sharded_llava": "sharded_llava"
 }
 
 def _get_classes(config: dict):
@@ -113,6 +113,7 @@ def load_model_shard(
     for wf in weight_files:
         weights_dict = mx.load(wf)
         all_weights_keys.update(weights_dict.keys())
+        weights.update({k: v for k, v in weights_dict.items() if not k.startswith("language_model.model.layers.") or shard.start_layer <= int(k.split('.')[3]) <= shard.end_layer})
         weights.update({k: v for k, v in weights_dict.items() if not k.startswith("model.layers.") or shard.start_layer <= int(k.split('.')[2]) <= shard.end_layer})
 
     model_class, model_args_class = _get_classes(config=config)
@@ -137,6 +138,11 @@ def load_model_shard(
             if shard.start_layer <= layer_num <= shard.end_layer:
                 new_key = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
                 filtered_weights[new_key] = v
+        elif k.startswith("language_model.model.layers."):
+            layer_num = int(k.split('.')[3])
+            if shard.start_layer <= layer_num <= shard.end_layer:
+                new_key = f"language_model.model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[4:])
+                filtered_weights[new_key] = v
         else:
             filtered_weights[k] = v
     weights = filtered_weights
@@ -228,62 +234,11 @@ async def load_shard(
     if adapter_path is not None:
         model = apply_lora_layers(model, adapter_path)
         model.eval()
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
-
-    return model, tokenizer
-
-
-async def load_shard_llava(
-    path_or_hf_repo: str,
-    shard: Shard,
-    tokenizer_config={},
-    model_config={},
-    adapter_path: Optional[str] = None,
-    lazy: bool = False,
-) -> Tuple[nn.Module, TokenizerWrapper]:
-    """
-    Load the model and tokenizer from a given path or a huggingface repository.
-
-    Args:
-        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
-        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
-            Defaults to an empty dictionary.
-        model_config(dict, optional): Configuration parameters specifically for the model.
-            Defaults to an empty dictionary.
-        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
-            to the model. Default: ``None``.
-        lazy (bool): If False eval the model parameters to make sure they are
-            loaded in memory before returning, otherwise they will be loaded
-            when needed. Default: ``False``
-    Returns:
-        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
-
-    Raises:
-        FileNotFoundError: If config file or safetensors are not found.
-        ValueError: If model class or args class are not found.
-    """
-    model_path = await get_model_path(path_or_hf_repo)
-    processor = AutoProcessor.from_pretrained(model_path)
-
-    with open(model_path / "config.json", "r") as f:
-        model_config = json.load(f)
-
-    model_config = LlaVAConfig.from_dict(model_config)
-
-    model_config.vision_config = VisionConfig.from_dict(model_config.vision_config)
-    model_config.text_config = TextConfig.from_dict(model_config.text_config)
-
-    model = LlavaModel(model_config)
-    weight_files = glob.glob(str(model_path / "*.safetensors"))
-    if not weight_files:
-        raise FileNotFoundError(f"No safetensors found in {model_path}")
-
-    weights = {}
-    for wf in weight_files:
-        weights.update(mx.load(wf))
-
-    weights = VisionModel.sanitize(weights)
-    weights = LanguageModel.sanitize(weights)
 
-    model.load_weights(list(weights.items()))
-    return model, processor
+    # TODO: figure out a better way
+    if "llama" in str(model_path):
+        tokenizer = load_tokenizer(model_path, tokenizer_config)
+        return model, tokenizer
+    elif "llava" in str(model_path):
+        processor = AutoProcessor.from_pretrained(model_path)
+        return model, processor

+ 35 - 34
exo/inference/mlx/test_sharded_llava.py

@@ -9,42 +9,20 @@ import mlx.core as mx
 from mlx_lm.models.base import KVCache
 
 from exo.inference.mlx.sharded_model import StatefulShardedModel
-from exo.inference.mlx.sharded_utils import load_shard_llava
+from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard
 
-def sample(logits, temperature=0.0):
-    if temperature == 0:
-        return mx.argmax(logits, axis=-1)
-    else:
-        return mx.random.categorical(logits * (1 / temperature))
-def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
-    kv_heads = (
-        [model.language_model.model.n_kv_heads] * len(model.language_model.model.layers)
-        if isinstance(model.language_model.model.n_kv_heads, int)
-        else model.language_model.model.n_kv_heads
-    )
-    cache = [KVCache(model.language_model.model.head_dim, n) for n in kv_heads]
-    logits = model(input_ids, pixel_values, cache=cache)
-    logits = logits[:, -1, :]
-    y = sample(logits, temperature=temperature)
-    tokens = [y.item()]
-
-    for n in range(max_tokens - 1):
-        logits = model.language_model(y[None], cache=cache)
-        logits = logits[:, -1, :]
-        y = sample(logits, temperature)
-        token = y.item()
-        if token == processor.tokenizer.eos_token_id:
-            break
-        tokens.append(token)
-
-    return processor.tokenizer.decode(tokens)
-
 shard_full = Shard("llava", 0, 31, 32)
+shard1 = Shard("llava", 0, 12, 32)
+shard2 = Shard("llava", 13, 31, 32)
 
-full_model_shard, full_processor = asyncio.run(load_shard_llava("llava-hf/llava-1.5-7b-hf", shard=shard_full))
+full_model_shard, full_processor = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard_full))
+model_shard1, processor1 = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard1))
+model_shard2, processor2 = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard2))
 
 full = StatefulShardedModel(shard_full, full_model_shard)
+m1 = StatefulShardedModel(shard1, model_shard1)
+m2 = StatefulShardedModel(shard2, model_shard2)
 
 PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
 IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
@@ -56,7 +34,30 @@ pixel_values = mx.array(inputs["pixel_values"])
 input_ids = mx.array(inputs["input_ids"])
 
 print(prompt)
-generated_text = generate_text(
-    input_ids, pixel_values, full_model_shard, full_processor, 10, 0
-)
-print(generated_text)
+y = full.step(input_ids, pixel_values, temp=0)
+full_generated_tokens = [y.item()]
+
+for _ in range(13):
+    y = full.step(y, temp=0)
+    full_generated_tokens.append(y.item())
+
+full_response = full_processor.tokenizer.decode(full_generated_tokens)
+print("full response:", full_response)
+
+inputs = processor1(prompt, img, return_tensors="np")
+pixel_values = mx.array(inputs["pixel_values"])
+input_ids = mx.array(inputs["input_ids"])
+
+y = m1.step(input_ids, pixel_values, temp=0)
+y = m2.step(y, temp=0)
+full_generated_tokens = [y.item()]
+
+for _ in range(13):
+    y = m1.step(y, temp=0)
+    y = m2.step(y, temp=0)
+    full_generated_tokens.append(y.item())
+
+sharded_response = processor2.tokenizer.decode(full_generated_tokens)
+print("sharded response:", sharded_response)
+
+assert full_response == sharded_response

+ 3 - 0
exo/inference/shard.py

@@ -13,6 +13,9 @@ class Shard:
     def is_last_layer(self) -> bool:
         return self.end_layer == self.n_layers - 1
 
+    def get_layer_count(self) -> int:
+        return self.end_layer - self.start_layer + 1
+
     def to_dict(self) -> dict:
         return {
             "model_id": self.model_id,