Browse Source

Merge pull request #88 from varshith15/main

Support for LLaVA
Alex Cheema 1 year ago
parent
commit
0ec77e1a99

+ 2 - 0
.gitignore

@@ -2,6 +2,7 @@ __pycache__/
 .venv
 test_weights.npz
 .exo_used_ports
+.idea
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -82,6 +83,7 @@ target/
 
 # Jupyter Notebook
 .ipynb_checkpoints
+Untitled.ipynb
 
 # IPython
 profile_default/

+ 29 - 3
README.md

@@ -27,7 +27,7 @@ Forget expensive NVIDIA GPUs, unify your existing devices into one powerful GPU:
 <div align="center">
   <h2>Update: Exo Supports Llama 3.1</h2>
   <p>Now the default models, run 8B, 70B and 405B parameter models on your own devices</p>
-  <p><a href="https://github.com/exo-explore/exo/blob/main/exo/inference/mlx/models/sharded_llama.py">See the code</a></p>
+  <p><a href="https://github.com/exo-explore/exo/blob/main/exo/inference/mlx/models/llama.py">See the code</a></p>
 </div>
 
 ## Get Involved
@@ -40,7 +40,7 @@ We also welcome contributions from the community. We have a list of bounties in
 
 ### Wide Model Support
 
-exo supports LLaMA ([MLX](exo/inference/mlx/models/sharded_llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)) and other popular models.
+exo supports LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)) and other popular models.
 
 ### Dynamic Model Partitioning
 
@@ -111,7 +111,7 @@ The native way to access models running on exo is using the exo library with pee
 
 exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:8000
 
-For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Example with curl:
+For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Example with curls:
 
 ```sh
 curl http://localhost:8000/v1/chat/completions \
@@ -123,6 +123,32 @@ curl http://localhost:8000/v1/chat/completions \
    }'
 ```
 
+```sh
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+     "model": "llava-1.5-7b-hf",
+     "messages": [
+      {
+        "role": "user",
+        "content": [
+          {
+            "type": "text",
+            "text": "What are these?"
+          },
+          {
+            "type": "image_url",
+            "image_url": {
+              "url": "http://images.cocodataset.org/val2017/000000039769.jpg"
+            }
+          }
+        ]
+      }
+    ],
+     "temperature": 0.0
+   }'
+```
+
 ## Debugging
 
 Enable debug logs with the DEBUG environment variable (0-9).

+ 69 - 12
exo/api/chatgpt_api.py

@@ -3,7 +3,7 @@ import time
 import asyncio
 import json
 from pathlib import Path
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoProcessor
 from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
@@ -42,11 +42,15 @@ shard_mappings = {
   "deepseek-coder-v2-lite": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
   },
+  ### llava
+  "llava-1.5-7b-hf": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
+  },
 }
 
 
 class Message:
-  def __init__(self, role: str, content: str):
+  def __init__(self, role: str, content: Union[str, list]):
     self.role = role
     self.content = content
 
@@ -68,6 +72,18 @@ def resolve_tinygrad_tokenizer(model_id: str):
 
 
 async def resolve_tokenizer(model_id: str):
+  try:
+    if DEBUG >= 2: print(f"Trying to AutoProcessor for {model_id}")
+    processor = AutoProcessor.from_pretrained(model_id)
+    processor.eos_token_id = processor.tokenizer.eos_token_id
+    processor.encode = processor.tokenizer.encode
+    return processor
+  except Exception as e:
+    if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}")
+    import traceback
+
+    if DEBUG >= 2: print(traceback.format_exc())
+
   try:
     if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
     return AutoTokenizer.from_pretrained(model_id)
@@ -137,8 +153,50 @@ def generate_completion(
   return completion
 
 
-def build_prompt(tokenizer, messages: List[Message]):
-  return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+def remap_messages(messages: List[Message]) -> List[Message]:
+    remapped_messages = []
+    last_image = None
+    for message in messages:
+        remapped_content = []
+        for content in message.content:
+            if isinstance(content, dict):
+                if content.get("type") in ["image_url", "image"]:
+                    image_url = content.get("image_url", {}).get("url") or content.get("image")
+                    if image_url:
+                        last_image = {"type": "image", "image": image_url}
+                        remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
+                else:
+                    remapped_content.append(content)
+            else:
+                remapped_content.append({"type": "text", "text": content})
+        remapped_messages.append(Message(role=message.role, content=remapped_content))
+
+    if last_image:
+        # Replace the last image placeholder with the actual image content
+        for message in reversed(remapped_messages):
+            for i, content in enumerate(message.content):
+                if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
+                    message.content[i] = last_image
+                    return remapped_messages
+
+    return remapped_messages
+
+def build_prompt(tokenizer, _messages: List[Message]):
+  messages = remap_messages(_messages)
+  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+  image_str = None
+  for message in messages:
+    if not isinstance(message.content, list):
+      continue
+
+    for content in message.content:
+      # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
+      # follows the convention in https://platform.openai.com/docs/guides/vision
+      if content.get("type", None) == "image":
+        image_str = content.get("image", None)
+        break
+
+  return prompt, image_str
 
 
 def parse_message(data: dict):
@@ -160,7 +218,7 @@ class ChatGPTAPI:
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout_secs = response_timeout_secs
-    self.app = web.Application()
+    self.app = web.Application(client_max_size=100 * 1024 * 1024)  # 100MB to support image upload
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
     cors = aiohttp_cors.setup(self.app)
@@ -187,7 +245,6 @@ class ChatGPTAPI:
     return middleware
 
   async def handle_root(self, request):
-    print(f"Handling root request from {request.remote}")
     return web.FileResponse(self.static_dir / "index.html")
 
   async def handle_post_chat_token_encode(self, request):
@@ -195,7 +252,7 @@ class ChatGPTAPI:
     shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(shard.model_id)
-    return web.json_response({"length": len(build_prompt(tokenizer, messages))})
+    return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
 
   async def handle_post_chat_completions(self, request):
     data = await request.json()
@@ -219,13 +276,13 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(shard.model_id)
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
-    prompt = build_prompt(tokenizer, chat_request.messages)
+    prompt, image_str = build_prompt(tokenizer, chat_request.messages)
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)
 
-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
+    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
     try:
-      await self.node.process_prompt(shard, prompt, request_id=request_id)
+      await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
     except Exception as e:
       if DEBUG >= 2:
         import traceback
@@ -252,7 +309,7 @@ class ChatGPTAPI:
           self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
           new_tokens = tokens[prev_last_tokens_len:]
           finish_reason = None
-          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
           if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
             new_tokens = new_tokens[:-1]
             if is_finished:
@@ -294,7 +351,7 @@ class ChatGPTAPI:
         )
 
         finish_reason = "length"
-        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
         if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
         if tokens[-1] == eos_token_id:
           tokens = tokens[:-1]

+ 633 - 0
exo/inference/mlx/models/llava.py

@@ -0,0 +1,633 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+import inspect
+from dataclasses import dataclass, field
+from typing import Optional, Dict, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.base import BaseModelArgs, KVCache
+from exo.inference.shard import Shard
+from .base import IdentityBlock
+import numpy as np
+
+
+@dataclass
+class VisionConfig:
+    model_type: str
+    num_hidden_layers: int = 24
+    hidden_size: int = 1024
+    intermediate_size: int = 4096
+    num_attention_heads: int = 16
+    image_size: int = 336
+    patch_size: int = 14
+    projection_dim: int = 768
+    vocab_size: int = 32000
+    num_channels: int = 3
+    layer_norm_eps: float = 1e-5
+
+    @classmethod
+    def from_dict(cls, params):
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
+
+
+class VisionAttention(nn.Module):
+    def __init__(
+            self,
+            dims: int,
+            num_heads: int,
+            query_input_dims: Optional[int] = None,
+            key_input_dims: Optional[int] = None,
+            value_input_dims: Optional[int] = None,
+            value_dims: Optional[int] = None,
+            value_output_dims: Optional[int] = None,
+            bias: bool = False,
+    ):
+        super().__init__()
+
+        if (dims % num_heads) != 0:
+            raise ValueError(
+                "The input feature dimensions should be divisible by the "
+                f"number of heads ({dims} % {num_heads}) != 0"
+            )
+
+        query_input_dims = query_input_dims or dims
+        key_input_dims = key_input_dims or dims
+        value_input_dims = value_input_dims or key_input_dims
+        value_dims = value_dims or dims
+        value_output_dims = value_output_dims or dims
+
+        self.num_heads = num_heads
+        self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+        self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+        self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+    def __call__(self, queries, keys, values, mask=None):
+        queries = self.q_proj(queries)
+        keys = self.k_proj(keys)
+        values = self.v_proj(values)
+
+        num_heads = self.num_heads
+        B, L, D = queries.shape
+        _, S, _ = keys.shape
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
+        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+        scale = math.sqrt(1 / queries.shape[-1])
+        scores = (queries * scale) @ keys
+        if mask is not None:
+            scores = scores + mask.astype(scores.dtype)
+        scores = mx.softmax(scores, axis=-1)
+        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.out_proj(values_hat)
+
+
+class VisionMLP(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.activation_fn = nn.GELU(approx="fast")
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.activation_fn(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+
+class VisionEncoderLayer(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = VisionAttention(
+            config.hidden_size, config.num_attention_heads, bias=True
+        )
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = VisionMLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        y = self.layer_norm1(x)
+        y = self.self_attn(y, y, y, mask)
+        x = x + y
+        y = self.layer_norm2(x)
+        y = self.mlp(y)
+        return x + y
+
+
+class VisionEncoder(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+
+class VisionEmbeddings(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.class_embedding = mx.zeros((config.hidden_size,))
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        batch_size = x.shape[0]
+        patch_embeddings = self.patch_embedding(x)
+        patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+        embed_dim = patch_embeddings.shape[-1]
+        cls_embeddings = mx.broadcast_to(
+            self.class_embedding, (batch_size, 1, embed_dim)
+        )
+        embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
+        embeddings += self.position_embedding.weight
+        return embeddings
+
+
+class ClipVisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.embeddings = VisionEmbeddings(config)
+        self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
+        self.encoder = VisionEncoder(config)
+        self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+    def __call__(
+        self,
+        x: mx.array,
+        output_hidden_states: Optional[bool] = None,
+    ) -> mx.array:
+        x = self.embeddings(x)
+        x = self.pre_layrnorm(x)
+
+        encoder_states = (x,) if output_hidden_states else None
+
+        for l in self.encoder.layers:
+            x = l(x, mask=None)
+            if output_hidden_states:
+                encoder_states = encoder_states + (x,)
+
+        pooler_output = self.post_layernorm(x[:, 0, :])
+        return pooler_output, x, encoder_states
+
+
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+
+        self.model_type = config.model_type
+        if self.model_type != "clip_vision_model":
+            raise ValueError(f"Unsupported model type: {self.model_type}")
+
+        self.vision_model = ClipVisionModel(config)
+
+    def __call__(
+            self, x: mx.array, output_hidden_states: Optional[bool] = None
+    ) -> mx.array:
+        return self.vision_model(x, output_hidden_states)
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+            elif "patch_embedding.weight" in k:
+                # PyTorch conv2d weight tensors have shape:
+                #   [out_channels, in_channels, kH, KW]
+                # MLX conv2d expects the weight be of shape:
+                #   [out_channels, kH, KW, in_channels]
+                sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
+
+
+@dataclass
+class TextConfig:
+    model_type: str
+    hidden_size: int = 4096
+    num_hidden_layers: int = 32
+    intermediate_size: int = 11008
+    num_attention_heads: int = 32
+    head_dim: int = None
+    rms_norm_eps: float = 1e-6
+    vocab_size: int = 32000
+    num_key_value_heads: int = None
+    rope_theta: float = 10000
+    rope_traditional: bool = False
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+
+    @classmethod
+    def from_dict(cls, params):
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.num_attention_heads
+
+        if self.model_type is None:
+            self.model_type = "llama"
+
+        if self.rope_scaling:
+            required_keys = {"factor", "type"}
+            if not all(key in self.rope_scaling for key in required_keys):
+                raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+            if self.rope_scaling["type"] != "linear":
+                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+
+
+class TextAttention(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+
+        dim = config.hidden_size
+        self.n_heads = n_heads = config.num_attention_heads
+        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+
+        self.repeats = n_heads // n_kv_heads
+
+        head_dim = config.hidden_size // n_heads
+        self.scale = head_dim ** -0.5
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+        rope_scale = (
+            1 / config.rope_scaling["factor"]
+            if config.rope_scaling is not None
+               and config.rope_scaling["type"] == "linear"
+            else 1
+        )
+        self.rope = nn.RoPE(
+            head_dim,
+            traditional=config.rope_traditional,
+            base=config.rope_theta,
+            scale=rope_scale,
+        )
+
+    def __call__(
+            self,
+            x: mx.array,
+            mask: Optional[mx.array] = None,
+            cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class TextMLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.num_attention_heads = config.num_attention_heads
+        self.hidden_size = config.hidden_size
+        self.self_attn = TextAttention(config)
+        self.mlp = TextMLP(config.hidden_size, config.intermediate_size)
+        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.config = config
+
+    def __call__(
+            self,
+            x: mx.array,
+            mask: Optional[mx.array] = None,
+            cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out
+
+
+class Llama(nn.Module):
+    def __init__(self, config: TextConfig, shard: Shard):
+        super().__init__()
+        self.config = config
+        self.shard = shard
+        self.vocab_size = config.vocab_size
+        self.model_type = config.model_type
+        self.num_hidden_layers = config.num_hidden_layers
+        self.num_key_value_heads = config.num_key_value_heads
+        self.head_dim = config.head_dim
+        assert self.vocab_size > 0
+        if self.shard.is_first_layer():
+            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = []
+        for i in range(self.num_hidden_layers):
+          if self.shard.start_layer <= i <= self.shard.end_layer:
+            self.layers.append(TransformerBlock(config=config))
+          else:
+            self.layers.append(IdentityBlock())
+        if self.shard.is_last_layer():
+            self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def __call__(
+            self,
+            inputs: mx.array,
+            cache=None,
+            inputs_embeds=None,
+    ):
+        # for passing merged input embeddings
+        if inputs_embeds is None:
+            if self.shard.is_first_layer():
+                h = self.embed_tokens(inputs)
+            else:
+                h = inputs
+        else:
+            h = inputs_embeds
+
+        mask = None
+        if h.shape[1] > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+            mask = mask.astype(h.dtype)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
+
+        if self.shard.is_last_layer():
+            h = self.norm(h)
+        return h
+
+class LanguageModel(nn.Module):
+    def __init__(self, config: TextConfig, shard: Shard):
+        super().__init__()
+        self.model_type = config.model_type
+        if self.model_type != "llama":
+            raise ValueError(
+                f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
+            )
+        self.shard = shard
+        self.model = Llama(config, shard)
+        if self.shard.is_last_layer():
+            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+        inputs_embeds=None,
+    ):
+        out = self.model(inputs, cache, inputs_embeds)
+        if self.shard.is_last_layer():
+            out = self.lm_head(out)
+        return out
+
+    def sanitize(self, weights):
+        shard_state_dict = {}
+        for key, value in weights.items():
+            if "self_attn.rotary_emb.inv_freq" in key:
+                continue
+
+            if key.startswith('language_model.model.layers.'):
+                layer_num = int(key.split('.')[3])
+                if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
+                    continue
+            if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'):
+                continue
+            elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')):
+                continue
+
+            shard_state_dict[key] = value
+
+        return shard_state_dict
+
+@dataclass
+class LlaVAConfig(BaseModelArgs):
+    text_config: TextConfig
+    vision_config: VisionConfig = None
+    model_type: str = "llava"
+    ignore_index: int = -100
+    image_token_index: int = 32000
+    vision_feature_select_strategy: str = "default"
+    vision_feature_layer: int = -2
+    vocab_size: int = 32000
+
+    @classmethod
+    def from_dict(cls, params):
+        updated_params = {}
+        class_params = inspect.signature(cls).parameters
+        for k, v in params.items():
+            if k in class_params:
+                if k in ["text_config", "vision_config"]:
+                    v = class_params[k].annotation.from_dict(v)
+                updated_params.update({k: v})
+
+        return cls(**updated_params)
+
+
+@dataclass
+class ModelArgs(LlaVAConfig):
+    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+    def __post_init__(self):
+        if isinstance(self.shard, dict):
+            self.shard = Shard(**self.shard)
+
+        if not isinstance(self.shard, Shard):
+            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+        if not self.shard.is_first_layer():
+            self.vision_config = None
+
+
+class LlavaMultiModalProjector(nn.Module):
+    def __init__(self, config: LlaVAConfig):
+        super().__init__()
+        self.linear_1 = nn.Linear(
+            config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
+        )
+        self.gelu = nn.GELU()
+        self.linear_2 = nn.Linear(
+            config.text_config.hidden_size, config.text_config.hidden_size, bias=True
+        )
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.linear_1(x)
+        x = self.gelu(x)
+        x = self.linear_2(x)
+        return x
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        if config.vision_config:
+            self.vision_tower = VisionModel(config.vision_config)
+            self.multi_modal_projector = LlavaMultiModalProjector(config)
+            self.vision_feature_layer = config.vision_feature_layer
+            self.vision_feature_select_strategy = config.vision_feature_select_strategy
+        self.language_model = LanguageModel(config.text_config, config.shard)
+
+    def get_input_embeddings(
+            self,
+            input_ids: Optional[mx.array] = None,
+            pixel_values: Optional[mx.array] = None,
+    ):
+        if pixel_values is None:
+            return self.language_model(input_ids)
+
+        # Get the input embeddings from the language model
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        # Get the ouptut hidden states from the vision model
+        *_, hidden_states = self.vision_tower(
+            pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
+        )
+
+        # Select the hidden states from the desired layer
+        selected_image_feature = hidden_states[self.vision_feature_layer]
+
+        if self.vision_feature_select_strategy == "default":
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif self.vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise ValueError(
+                "Unexpected feature selection strategy: "
+                f"{self.vision_feature_select_strategy}"
+            )
+
+        # Pass image features through the multi-modal projector
+        image_features = self.multi_modal_projector(selected_image_feature)
+
+        # Insert special image tokens in the input_ids
+        final_inputs_embeds = self._merge_input_ids_with_image_features(
+            image_features, inputs_embeds, input_ids
+        )
+        return final_inputs_embeds
+
+    def _merge_input_ids_with_image_features(
+            self, image_features, inputs_embeds, input_ids
+    ):
+        image_token_index = self.config.image_token_index
+        num_images, num_image_patches, embed_dim = image_features.shape
+
+        # Positions of <image> tokens in input_ids, assuming batch size is 1
+        image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+
+        if len(image_positions) != num_images:
+            raise ValueError(
+                f"The number of image tokens ({len(image_positions)}) does not "
+                f" match the number of image inputs ({num_images})."
+            )
+
+        text_segments = []
+        start_idx = 0
+
+        for position in image_positions:
+            text_segments.append(inputs_embeds[:, start_idx:position])
+            start_idx = position + 1
+
+        image_embeddings = mx.split(image_features, image_features.shape[0])
+        final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
+        final_embeddings += [inputs_embeds[:, start_idx:]]
+
+        # Create a final embedding of shape
+        # (1, num_image_patches*num_images + sequence_len, embed_dim)
+        return mx.concatenate(final_embeddings, axis=1)
+
+    def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
+        input_embddings = None
+        if pixel_values is not None:
+            input_embddings = self.get_input_embeddings(input_ids, pixel_values)
+        logits = self.language_model(
+            input_ids, cache=cache, inputs_embeds=input_embddings
+        )
+        return logits
+
+    def sanitize(self, weights):
+        if self.config.vision_config:
+          weights = self.vision_tower.sanitize(weights)
+        else:
+          weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
+        weights = self.language_model.sanitize(weights)
+        return weights
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    @property
+    def head_dim(self):
+        return (
+                self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads
+        )
+
+    @property
+    def n_kv_heads(self):
+        return self.language_model.model.num_key_value_heads

+ 10 - 3
exo/inference/mlx/sharded_inference_engine.py

@@ -2,7 +2,7 @@ import numpy as np
 import mlx.core as mx
 from ..inference_engine import InferenceEngine
 from .sharded_model import StatefulShardedModel
-from .sharded_utils import load_shard
+from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
 from typing import Optional
 
@@ -11,9 +11,16 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self):
     self.shard = None
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
-    output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt))))
+    if image_str:
+      image = await get_image_from_str(image_str)
+      inputs = self.tokenizer(prompt, image, return_tensors="np")
+      pixel_values = mx.array(inputs["pixel_values"])
+      input_ids = mx.array(inputs["input_ids"])
+      output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, input_ids, pixel_values))
+    else:
+      output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt))))
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):

+ 6 - 1
exo/inference/mlx/sharded_model.py

@@ -18,6 +18,7 @@ class StatefulShardedModel:
     self,
     request_id: str,
     x,
+    pixel_values=None,
     temp: float = 0.0,
     top_p: float = 1.0,
     logit_bias: Optional[Dict[int, float]] = None,
@@ -42,7 +43,11 @@ class StatefulShardedModel:
 
     if request_id not in self.request_cache:
       self.init_cache(request_id)
-    output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.request_cache[request_id])
+
+    if pixel_values is None:
+      output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.request_cache[request_id])
+    else:
+      output = self.model(y, pixel_values=pixel_values, cache=self.request_cache[request_id])
 
     if self.shard.is_last_layer():
       logits = output[:, -1, :]

+ 41 - 3
exo/inference/mlx/sharded_utils.py

@@ -5,14 +5,21 @@ import importlib
 import json
 import logging
 import asyncio
+import aiohttp
 from functools import partial
 from pathlib import Path
 from typing import Optional, Tuple
+import requests
+from PIL import Image
+from io import BytesIO
+import base64
 
+from exo import DEBUG
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils._errors import RepositoryNotFoundError
+from transformers import AutoProcessor
 
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tuner.utils import apply_lora_layers
@@ -128,7 +135,7 @@ def load_model_shard(
       class_predicate=None,
     )
 
-  model.load_weights(list(weights.items()))
+  model.load_weights(list(weights.items()), strict=True)
 
   if not lazy:
     mx.eval(model.parameters())
@@ -217,6 +224,37 @@ async def load_shard(
   if adapter_path is not None:
     model = apply_lora_layers(model, adapter_path)
     model.eval()
-  tokenizer = load_tokenizer(model_path, tokenizer_config)
 
-  return model, tokenizer
+  # TODO: figure out a generic solution
+  if model.model_type == "llava":
+    processor = AutoProcessor.from_pretrained(model_path)
+    processor.eos_token_id = processor.tokenizer.eos_token_id
+    processor.encode = processor.tokenizer.encode
+    return model, processor
+  else:
+    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    return model, tokenizer
+
+async def get_image_from_str(_image_str: str):
+    image_str = _image_str.strip()
+
+    if image_str.startswith("http"):
+        async with aiohttp.ClientSession() as session:
+            async with session.get(image_str, timeout=10) as response:
+                content = await response.read()
+                return Image.open(BytesIO(content)).convert("RGB")
+    elif image_str.startswith("data:image/"):
+        # Extract the image format and base64 data
+        format_prefix, base64_data = image_str.split(";base64,")
+        image_format = format_prefix.split("/")[1].lower()
+        if DEBUG >= 2: print(f"{image_str=} {image_format=}")
+        imgdata = base64.b64decode(base64_data)
+        img = Image.open(BytesIO(imgdata))
+
+        # Convert to RGB if not already
+        if img.mode != "RGB":
+            img = img.convert("RGB")
+
+        return img
+    else:
+        raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")

+ 64 - 0
exo/inference/mlx/test_sharded_llava.py

@@ -0,0 +1,64 @@
+import codecs
+import asyncio
+import requests
+from PIL import Image
+from io import BytesIO
+
+import mlx.core as mx
+from mlx_lm.models.base import KVCache
+
+from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.sharded_utils import load_shard
+from exo.inference.shard import Shard
+
+shard_full = Shard("llava", 0, 31, 32)
+shard1 = Shard("llava", 0, 12, 32)
+shard2 = Shard("llava", 13, 31, 32)
+
+model_path = "llava-hf/llava-1.5-7b-hf"
+
+full_model_shard, full_processor = asyncio.run(load_shard(model_path, shard=shard_full))
+model_shard1, processor1 = asyncio.run(load_shard(model_path, shard=shard1))
+model_shard2, processor2 = asyncio.run(load_shard(model_path, shard=shard2))
+
+full = StatefulShardedModel(shard_full, full_model_shard)
+m1 = StatefulShardedModel(shard1, model_shard1)
+m2 = StatefulShardedModel(shard2, model_shard2)
+
+PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
+IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
+response = requests.get(IMAGE_FILE)
+img = Image.open(BytesIO(response.content))
+prompt = codecs.decode(PROMPT, "unicode_escape")
+inputs = full_processor(prompt, img, return_tensors="np")
+pixel_values = mx.array(inputs["pixel_values"])
+input_ids = mx.array(inputs["input_ids"])
+
+print(prompt)
+y = full.step("full", input_ids, pixel_values, temp=0)
+full_generated_tokens = [y.item()]
+
+for _ in range(13):
+    y = full.step("full", y, temp=0)
+    full_generated_tokens.append(y.item())
+
+full_response = full_processor.tokenizer.decode(full_generated_tokens)
+print("full response:", full_response)
+
+inputs = processor1(prompt, img, return_tensors="np")
+pixel_values = mx.array(inputs["pixel_values"])
+input_ids = mx.array(inputs["input_ids"])
+
+y = m1.step("shard", input_ids, pixel_values, temp=0)
+y = m2.step("shard", y, temp=0)
+full_generated_tokens = [y.item()]
+
+for _ in range(13):
+    y = m1.step("shard", y, temp=0)
+    y = m2.step("shard", y, temp=0)
+    full_generated_tokens.append(y.item())
+
+sharded_response = processor2.tokenizer.decode(full_generated_tokens)
+print("sharded response:", sharded_response)
+
+assert full_response == sharded_response

+ 3 - 0
exo/inference/shard.py

@@ -14,6 +14,9 @@ class Shard:
   def is_last_layer(self) -> bool:
     return self.end_layer == self.n_layers - 1
 
+  def get_layer_count(self) -> int:
+    return self.end_layer - self.start_layer + 1
+
   def to_dict(self) -> dict:
     return {
       "model_id": self.model_id,

+ 2 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -39,9 +39,10 @@ class GRPCPeerHandle(PeerHandle):
     self.channel = None
     self.stub = None
 
-  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
+      image_str=image_str,
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
         start_layer=shard.start_layer,

+ 3 - 2
exo/networking/grpc/grpc_server.py

@@ -45,9 +45,10 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       n_layers=request.shard.n_layers,
     )
     prompt = request.prompt
+    image_str = request.image_str
     request_id = request.request_id
-    result = await self.node.process_prompt(shard, prompt, request_id)
-    if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
+    result = await self.node.process_prompt(shard, prompt, image_str, request_id)
+    if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {image=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 

+ 3 - 2
exo/networking/grpc/node_service.proto

@@ -21,8 +21,9 @@ message Shard {
 message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
-  optional string request_id = 3;
-  optional string inference_state = 4;
+  optional string image_str = 3;
+  optional string request_id = 4;
+  optional string inference_state = 5;
 }
 
 message TensorRequest {

File diff suppressed because it is too large
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 1 - 1
exo/networking/peer_handle.py

@@ -28,7 +28,7 @@ class PeerHandle(ABC):
     pass
 
   @abstractmethod
-  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     pass
 
   @abstractmethod

+ 1 - 1
exo/orchestration/node.py

@@ -16,7 +16,7 @@ class Node(ABC):
     pass
 
   @abstractmethod
-  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     pass
 
   @abstractmethod

+ 12 - 9
exo/orchestration/standard_node.py

@@ -69,7 +69,7 @@ class StandardNode(Node):
     await self.discovery.stop()
     await self.server.stop()
 
-  async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
       self.broadcast_opaque_status(
@@ -82,6 +82,7 @@ class StandardNode(Node):
             "base_shard": base_shard.to_dict(),
             "shard": shard.to_dict(),
             "prompt": prompt,
+            "image_str": image_str,
             "inference_state": inference_state,
             "request_id": request_id,
           }
@@ -89,7 +90,7 @@ class StandardNode(Node):
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
+    resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -103,6 +104,7 @@ class StandardNode(Node):
             "base_shard": base_shard.to_dict(),
             "shard": shard.to_dict(),
             "prompt": prompt,
+            "image_str": image_str,
             "inference_state": inference_state,
             "request_id": request_id,
             "elapsed_time_ns": elapsed_time_ns,
@@ -113,20 +115,20 @@ class StandardNode(Node):
     )
     return resp
 
-  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     if request_id not in self.buffered_token_output:
       self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
 
-    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}")
     if shard.start_layer != 0:
-      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      await self.forward_to_next_shard(shard, prompt, request_id)
+      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}")
+      await self.forward_to_next_shard(shard, prompt, request_id, image_str)
       return
 
-    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state)
+    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state)
     is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -234,6 +236,7 @@ class StandardNode(Node):
     base_shard: Shard,
     tensor_or_prompt: Union[np.ndarray, str],
     request_id: str,
+    image_str: Optional[str] = None,
     inference_state: Optional[str] = None,
   ) -> None:
     if not self.partitioning_strategy:
@@ -255,7 +258,7 @@ class StandardNode(Node):
         if isinstance(tensor_or_prompt, np.ndarray):
           await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
         else:
-          await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+          await self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state)
         return
 
       target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
@@ -267,7 +270,7 @@ class StandardNode(Node):
       if isinstance(tensor_or_prompt, np.ndarray):
         await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
       else:
-        await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+        await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state)
 
   def get_current_shard(self, base_shard: Shard) -> Shard:
     partitions = self.partitioning_strategy.partition(self.topology)

+ 3 - 2
setup.py

@@ -12,6 +12,7 @@ install_requires = [
     "huggingface-hub==0.23.4",
     "Jinja2==3.1.4",
     "numpy==2.0.0",
+    "pillow==10.4.0",
     "prometheus-client==0.20.0",
     "protobuf==5.27.1",
     "psutil==6.0.0",
@@ -22,7 +23,7 @@ install_requires = [
     "tiktoken==0.7.0",
     "tokenizers==0.19.1",
     "tqdm==4.66.4",
-    "transformers==4.41.2",
+    "transformers==4.43.3",
     "uuid==1.30",
     "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@639af3f823cf242a1945dc24183e52a9df0af2b7",
 ]
@@ -41,7 +42,7 @@ extras_require = {
         "pylint==3.2.6",
         "ruff==0.5.5",
         "mypy==1.11.0",
-    ]
+    ],
 }
 
 setup(

+ 89 - 0
tinychat/examples/tinychat/index.css

@@ -291,3 +291,92 @@ p {
 .monospace {
   font-family: monospace;
 }
+
+.model-selector {
+  display: flex;
+  justify-content: center;
+  padding: 20px 0;
+}
+.model-selector select {
+  padding: 10px 20px;
+  font-size: 16px;
+  border: 1px solid #ccc;
+  border-radius: 5px;
+  background-color: #f8f8f8;
+  cursor: pointer;
+}
+.model-selector select:focus {
+  outline: none;
+  border-color: #007bff;
+  box-shadow: 0 0 0 2px rgba(0,123,255,.25);
+}
+
+/* Image upload button styles */
+.image-input-button {
+  background-color: var(--secondary-color);
+  color: var(--foreground-color);
+  border: none;
+  border-radius: 50%;
+  width: 40px;
+  height: 40px;
+  font-size: 18px;
+  cursor: pointer;
+  transition: all 0.3s ease;
+  display: flex;
+  align-items: center;
+  justify-content: center;
+  margin-right: 10px;
+}
+
+.image-input-button:hover {
+  background-color: var(--secondary-color-transparent);
+  transform: scale(1.1);
+}
+
+.image-input-button:focus {
+  outline: none;
+  box-shadow: 0 0 0 3px rgba(var(--secondary-color-rgb), 0.5);
+}
+
+.image-input-button i {
+  transition: all 0.3s ease;
+}
+
+.image-input-button:hover i {
+  transform: scale(1.2);
+}
+
+/* Hidden file input styles */
+#image-upload {
+  display: none;
+}
+
+.image-preview-container {
+  position: relative;
+  display: inline-block;
+  margin-right: 10px;
+}
+
+.image-preview {
+  max-width: 100px;
+  max-height: 100px;
+  object-fit: cover;
+  border-radius: 5px;
+}
+
+.remove-image-button {
+  position: absolute;
+  top: -5px;
+  right: -5px;
+  background-color: rgba(255, 255, 255, 0.8);
+  border: none;
+  border-radius: 50%;
+  padding: 2px 5px;
+  cursor: pointer;
+}
+
+.message > p > img {
+  max-width: 100%;
+  max-height: 100%;
+  object-fit: contain;
+}

+ 12 - 22
tinychat/examples/tinychat/index.html

@@ -30,27 +30,6 @@
 
   <link rel="stylesheet" href="index.css">
   <link rel="stylesheet" href="common.css">
-
-  <style>
-    .model-selector {
-      display: flex;
-      justify-content: center;
-      padding: 20px 0;
-    }
-    .model-selector select {
-      padding: 10px 20px;
-      font-size: 16px;
-      border: 1px solid #ccc;
-      border-radius: 5px;
-      background-color: #f8f8f8;
-      cursor: pointer;
-    }
-    .model-selector select:focus {
-      outline: none;
-      border-color: #007bff;
-      box-shadow: 0 0 0 2px rgba(0,123,255,.25);
-    }
-  </style>
 </head>
 
 <body>
@@ -65,6 +44,7 @@
         <option value="mistral-nemo">Mistral Nemo</option>
         <option value="mistral-large">Mistral Large</option>
         <option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
+        <option value="llava-1.5-7b-hf">LLaVa 1.5 7B (Vision Model)</option>
       </select>
     </div>
     <div class="home centered" x-show="home === 0" x-transition x-effect="
@@ -180,6 +160,16 @@
         </span>
       </div>
       <div class="input">
+        <button x-show="cstate.selectedModel === 'llava-1.5-7b-hf'" class="image-input-button" @click="$refs.imageUpload.click()">
+          <i class="fas fa-image"></i>
+        </button>
+        <input x-ref="imageUpload" type="file" id="image-upload" accept="image/*" @change="$data.handleImageUpload($event)" style="display: none;">
+        <div x-show="imagePreview" class="image-preview-container">
+          <img :src="imagePreview" alt="Uploaded Image" class="image-preview">
+          <button @click="imagePreview = null; imageUrl = null;" class="remove-image-button">
+            <i class="fas fa-times"></i>
+          </button>
+        </div>
         <textarea x-ref="inputForm" id="input-form" class="input-form" autofocus rows=1 x-autosize
           :placeholder="generating ? 'Generating...' : 'Say something'" :disabled="generating" @input="
             home = (home === 0) ? 1 : home
@@ -208,4 +198,4 @@
   </main>
 </body>
 
-</html>
+</html>

+ 68 - 3
tinychat/examples/tinychat/index.js

@@ -4,6 +4,7 @@ document.addEventListener("alpine:init", () => {
     cstate: {
       time: null,
       messages: [],
+      selectedModel: 'llama-3.1-8b',
     },
 
     // historical state
@@ -18,6 +19,9 @@ document.addEventListener("alpine:init", () => {
     tokens_per_second: 0,
     total_tokens: 0,
 
+    // image handling
+    imagePreview: null,
+
     removeHistory(cstate) {
       const index = this.histories.findIndex((state) => {
         return state.time === cstate.time;
@@ -28,10 +32,28 @@ document.addEventListener("alpine:init", () => {
       }
     },
 
+    async handleImageUpload(event) {
+      const file = event.target.files[0];
+      if (file) {
+        const reader = new FileReader();
+        reader.onload = (e) => {
+          this.imagePreview = e.target.result;
+          this.imageUrl = e.target.result; // Store the image URL
+          // Add image preview to the chat
+          this.cstate.messages.push({
+            role: "user",
+            content: `![Uploaded Image](${this.imagePreview})`,
+          });
+        };
+        reader.readAsDataURL(file);
+      }
+    },
+
+
     async handleSend() {
       const el = document.getElementById("input-form");
       const value = el.value.trim();
-      if (!value) return;
+      if (!value && !this.imagePreview) return;
 
       if (this.generating) return;
       this.generating = true;
@@ -41,7 +63,9 @@ document.addEventListener("alpine:init", () => {
       window.history.pushState({}, "", "/");
 
       // add message to list
-      this.cstate.messages.push({ role: "user", content: value });
+      if (value) {
+        this.cstate.messages.push({ role: "user", content: value });
+      }
 
       // clear textarea
       el.value = "";
@@ -54,10 +78,51 @@ document.addEventListener("alpine:init", () => {
       let tokens = 0;
       this.tokens_per_second = 0;
 
+      // prepare messages for API request
+      const apiMessages = this.cstate.messages.map(msg => {
+        if (msg.content.startsWith('![Uploaded Image]')) {
+          return {
+            role: "user",
+            content: [
+              {
+                type: "image_url",
+                image_url: {
+                  url: this.imageUrl
+                }
+              }
+            ]
+          };
+        } else {
+          return {
+            role: msg.role,
+            content: [
+              {
+                type: "text",
+                text: msg.content
+              }
+            ]
+          };
+        }
+      });
+
+      // If there's an image URL, add it to all messages
+      if (this.imageUrl) {
+        apiMessages.forEach(msg => {
+          if (!msg.content.some(content => content.type === "image_url")) {
+            msg.content.push({
+              type: "image_url",
+              image_url: {
+                url: this.imageUrl
+              }
+            });
+          }
+        });
+      }
+
       // start receiving server sent events
       let gottenFirstChunk = false;
       for await (
-        const chunk of this.openaiChatCompletion(this.cstate.selectedModel, this.cstate.messages)
+        const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
       ) {
         if (!gottenFirstChunk) {
           this.cstate.messages.push({ role: "assistant", content: "" });

Some files were not shown because too many files changed in this diff