working test

Varshith · 1 year ago · commit 7cbf6a35bd

exo/inference/mlx/models/sharded_llava.py (+28 -21)

@@ -10,6 +10,7 @@ from typing import Optional, Dict, Union, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
+from mlx_lm.models.base import KVCache
 import numpy as np
 from huggingface_hub import snapshot_download
 
@@ -236,7 +237,8 @@ class TextConfig:
     num_attention_heads: int = 32
     rms_norm_eps: float = 1e-6
     vocab_size: int = 32000
-    num_key_value_heads: int = None
+    n_kv_heads: Optional[int] = None
+    head_dim: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -252,8 +254,11 @@ class TextConfig:
         )
 
     def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
+        if self.n_kv_heads is None:
+            self.n_kv_heads = self.num_attention_heads
+
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.num_attention_heads
 
         if self.rope_scaling:
             required_keys = {"factor", "type"}
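
The two defaults added here feed the KV cache allocation later in this commit. As a worked example, assuming llava-1.5-7b's text config (hidden_size=4096, num_attention_heads=32; those values come from the model's config.json, not from this diff):

    # Worked example of the __post_init__ defaults above; the config values
    # are assumptions taken from llava-1.5-7b.
    n_kv_heads = 32          # falls back to num_attention_heads when unset
    head_dim = 4096 // 32    # = 128, the per-head width each KVCache is sized with
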
@@ -270,7 +275,7 @@ class TextAttention(nn.Module):
 
         dim = config.hidden_size
         self.n_heads = n_heads = config.num_attention_heads
-        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+        self.n_kv_heads = n_kv_heads = config.n_kv_heads
 
         self.repeats = n_heads // n_kv_heads
 
@@ -299,7 +304,7 @@ class TextAttention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -311,11 +316,9 @@ class TextAttention(nn.Module):
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@@ -324,7 +327,7 @@ class TextAttention(nn.Module):
             queries, keys, values, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output), (keys, values)
+        return self.o_proj(output)
 
 
 class TextMLP(nn.Module):
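
For context on the new cache type: mlx_lm's KVCache grows in place and tracks how many positions it already holds, which is why the attention layer can read cache.offset for the RoPE offset and no longer needs to return (keys, values). A minimal sketch of that behaviour, assuming the mlx_lm version this commit targets (shapes here are illustrative):

    import mlx.core as mx
    from mlx_lm.models.base import KVCache

    B, n_kv_heads, head_dim = 1, 32, 128
    cache = KVCache(head_dim, n_kv_heads)

    # Prefill with 5 tokens: the cache stores them and advances its offset.
    k = mx.zeros((B, n_kv_heads, 5, head_dim))
    v = mx.zeros((B, n_kv_heads, 5, head_dim))
    keys, values = cache.update_and_fetch(k, v)
    assert cache.offset == 5 and keys.shape[2] == 5

    # One decode step: update_and_fetch returns the full history so far.
    keys, values = cache.update_and_fetch(
        mx.zeros((B, n_kv_heads, 1, head_dim)),
        mx.zeros((B, n_kv_heads, 1, head_dim)),
    )
    assert cache.offset == 6 and keys.shape[2] == 6
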
@@ -355,13 +358,13 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
         out = h + r
-        return out, cache
+        return out
 
 
 class Llama(nn.Module):
@@ -370,6 +373,8 @@ class Llama(nn.Module):
         self.config = config
         self.vocab_size = config.vocab_size
         self.num_hidden_layers = config.num_hidden_layers
+        self.n_kv_heads = config.n_kv_heads
+        self.head_dim = config.head_dim
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = [
@@ -397,10 +402,11 @@ class Llama(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
 
-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, mask, cache[e])
-
-        return self.norm(h), cache
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
+
+        return self.norm(h)
 
 
 class LanguageModel(nn.Module):
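
Since update_and_fetch mutates each KVCache in place, the layer loop above no longer threads cache state back out; callers allocate the caches once and reuse them across decode steps. A sketch of that caller-side allocation (make_cache is a hypothetical helper; its body mirrors reset() in sharded_model.py below):

    from mlx_lm.models.base import KVCache

    def make_cache(model):
        # One stateful KVCache per transformer layer, sized from the
        # n_kv_heads / head_dim attributes Llama now exposes.
        kv_heads = (
            [model.n_kv_heads] * len(model.layers)
            if isinstance(model.n_kv_heads, int)
            else model.n_kv_heads
        )
        return [KVCache(model.head_dim, n) for n in kv_heads]

    # Allocate once, then pass the same list on every call; nothing is
    # returned because update_and_fetch fills the caches in place.
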
@@ -420,8 +426,8 @@ class LanguageModel(nn.Module):
         cache=None,
         inputs_embeds=None,
     ):
-        out, cache = self.model(inputs, cache, inputs_embeds)
-        return self.lm_head(out), cache
+        out = self.model(inputs, cache, inputs_embeds)
+        return self.lm_head(out)
 
     @staticmethod
     def sanitize(weights):
@@ -435,6 +441,7 @@ class LanguageModel(nn.Module):
 class LlaVAConfig:
     text_config: TextConfig
     vision_config: VisionConfig
+    model_type: str = "llava"
     ignore_index: int = -100
     image_token_index: int = 32000
     vision_feature_select_strategy: str = "default"
@@ -549,10 +556,10 @@ class LlavaModel(nn.Module):
 
     def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
         input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
-        logits, cache = self.language_model(
+        logits = self.language_model(
             input_ids, cache=cache, inputs_embeds=input_embeddings
         )
-        return logits, cache
+        return logits
 
     @staticmethod
     def from_pretrained(path_or_hf_repo: str):

exo/inference/mlx/sharded_model.py (+9 -4)

@@ -57,9 +57,14 @@ class StatefulShardedModel:
         return self.step(x, temp, top_p, logit_bias)
 
     def reset(self):
+        if hasattr(self.model.config, "vision_config"):
+            model = self.model.language_model.model
+        else:
+            model = self.model
+
         kv_heads = (
-            [self.model.n_kv_heads] * len(self.model.layers)
-            if isinstance(self.model.n_kv_heads, int)
-            else self.model.n_kv_heads
+            [model.n_kv_heads] * len(model.layers)
+            if isinstance(model.n_kv_heads, int)
+            else model.n_kv_heads
         )
-        self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+        self.cache = [KVCache(model.head_dim, n) for n in kv_heads]
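
The hasattr check keeps a single reset path for both model families: a plain text model is its own decoder, while LLaVA nests it at model.language_model.model. The same unwrapping as a standalone helper, for illustration only (unwrap_decoder is a hypothetical name):

    def unwrap_decoder(model):
        # LlaVAConfig carries a vision_config field; plain text configs do not.
        if hasattr(model.config, "vision_config"):
            return model.language_model.model  # LlavaModel -> LanguageModel -> Llama
        return model
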

exo/inference/mlx/test_sharded_llava.py (+62 -0)

@@ -0,0 +1,62 @@
+import torch
+import codecs
+import asyncio
+import requests
+from PIL import Image
+from io import BytesIO
+
+import mlx.core as mx
+from mlx_lm.models.base import KVCache
+
+from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.sharded_utils import load_shard_llava
+from exo.inference.shard import Shard
+
+def sample(logits, temperature=0.0):
+    if temperature == 0:
+        return mx.argmax(logits, axis=-1)
+    else:
+        return mx.random.categorical(logits * (1 / temperature))
+
+def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
+    kv_heads = (
+        [model.language_model.model.n_kv_heads] * len(model.language_model.model.layers)
+        if isinstance(model.language_model.model.n_kv_heads, int)
+        else model.language_model.model.n_kv_heads
+    )
+    cache = [KVCache(model.language_model.model.head_dim, n) for n in kv_heads]
+    logits = model(input_ids, pixel_values, cache=cache)
+    logits = logits[:, -1, :]
+    y = sample(logits, temperature=temperature)
+    tokens = [y.item()]
+
+    for n in range(max_tokens - 1):
+        logits = model.language_model(y[None], cache=cache)
+        logits = logits[:, -1, :]
+        y = sample(logits, temperature)
+        token = y.item()
+        if token == processor.tokenizer.eos_token_id:
+            break
+        tokens.append(token)
+
+    return processor.tokenizer.decode(tokens)
+
+shard_full = Shard("llava", 0, 31, 32)
+
+full_model_shard, full_processor = asyncio.run(load_shard_llava("llava-hf/llava-1.5-7b-hf", shard=shard_full))
+
+full = StatefulShardedModel(shard_full, full_model_shard)
+
+PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
+IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
+response = requests.get(IMAGE_FILE)
+img = Image.open(BytesIO(response.content))
+prompt = codecs.decode(PROMPT, "unicode_escape")
+inputs = full_processor(prompt, img, return_tensors="np")
+pixel_values = mx.array(inputs["pixel_values"])
+input_ids = mx.array(inputs["input_ids"])
+
+print(prompt)
+generated_text = generate_text(
+    input_ids, pixel_values, full_model_shard, full_processor, 10, 0
+)
+print(generated_text)
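
The run above decodes greedily (temperature 0 takes the argmax). To exercise the stochastic branch of sample() as well, the same call works with arbitrary non-zero settings:

    # Non-zero temperature routes sample() through mx.random.categorical.
    sampled = generate_text(
        input_ids, pixel_values, full_model_shard, full_processor, 32, 0.7
    )
    print(sampled)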