
increase max request size to send raw images, make image download from URL async, use ChatGPT-compatible convention for images

Alex Cheema 9 months ago
parent
commit
0d45a855fb

+ 4 - 2
README.md

@@ -137,8 +137,10 @@ curl http://localhost:8000/v1/chat/completions \
             "text": "What are these?"
           },
           {
-            "type": "image",
-            "image": "http://images.cocodataset.org/val2017/000000039769.jpg"
+            "type": "image_url",
+            "image_url": {
+              "url": "http://images.cocodataset.org/val2017/000000039769.jpg"
+            }
           }
         ]
       }
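
For a quick sanity check of the new convention, here is a minimal Python sketch of the same request (assumes the exo server is running on localhost:8000 as in the README's curl example; the model name is illustrative, not taken from the commit):

```python
import requests

# Sketch of the OpenAI-style "image_url" content part this commit adopts.
# Endpoint and image URL follow the README's curl example; the model name
# is an assumption for illustration only.
resp = requests.post(
  "http://localhost:8000/v1/chat/completions",
  json={
    "model": "llava-1.5-7b-hf",
    "messages": [{
      "role": "user",
      "content": [
        {"type": "text", "text": "What are these?"},
        {"type": "image_url", "image_url": {"url": "http://images.cocodataset.org/val2017/000000039769.jpg"}},
      ],
    }],
  },
)
print(resp.json())
```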

+ 35 - 5
exo/api/chatgpt_api.py

@@ -153,14 +153,45 @@ def generate_completion(
   return completion
 
 
-def build_prompt(tokenizer, messages: List[Message]):
-  prompt =  tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+def remap_messages(messages: List[Message]) -> List[Message]:
+  remapped_messages = []
+  last_image = None
+  for message in messages:
+    remapped_content = []
+    for content in message.content:
+      if isinstance(content, dict):
+        if content.get("type") in ["image_url", "image"]:
+          image_url = content.get("image_url", {}).get("url") or content.get("image")
+          if image_url:
+            last_image = {"type": "image", "image": image_url}
+            remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
+        else:
+          remapped_content.append(content)
+      else:
+        remapped_content.append({"type": "text", "text": content})
+    remapped_messages.append(Message(role=message.role, content=remapped_content))
+
+  if last_image:
+    # Replace the last image placeholder with the actual image content
+    for message in reversed(remapped_messages):
+      for i, content in enumerate(message.content):
+        if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
+          message.content[i] = last_image
+          return remapped_messages
+
+  return remapped_messages
+
+def build_prompt(tokenizer, _messages: List[Message]):
+  messages = remap_messages(_messages)
+  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
   image_str = None
   for message in messages:
     if not isinstance(message.content, list):
       continue
 
     for content in message.content:
+      # note: we only support one image at a time right now. Multiple images are possible; see: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
+      # follows the convention in https://platform.openai.com/docs/guides/vision
       if content.get("type", None) == "image":
         image_str = content.get("image", None)
         break
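
To make the remapping concrete, here is a hypothetical standalone run (assumes Message is the simple role/content model this API uses):

```python
# Every image part is first replaced by a text placeholder; the placeholder
# in the most recent message is then swapped back to a single
# {"type": "image", ...} part, since only one image per prompt is supported.
messages = [Message(role="user", content=[
  {"type": "text", "text": "What are these?"},
  {"type": "image_url", "image_url": {"url": "http://images.cocodataset.org/val2017/000000039769.jpg"}},
])]
remapped = remap_messages(messages)
# remapped[0].content ->
#   [{"type": "text", "text": "What are these?"},
#    {"type": "image", "image": "http://images.cocodataset.org/val2017/000000039769.jpg"}]
```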
@@ -187,7 +218,7 @@ class ChatGPTAPI:
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout_secs = response_timeout_secs
-    self.app = web.Application()
+    self.app = web.Application(client_max_size=100 * 1024 * 1024)  # 100MB to support image upload
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
     cors = aiohttp_cors.setup(self.app)
@@ -214,7 +245,6 @@ class ChatGPTAPI:
     return middleware
 
   async def handle_root(self, request):
-    print(f"Handling root request from {request.remote}")
     return web.FileResponse(self.static_dir / "index.html")
 
   async def handle_post_chat_token_encode(self, request):
@@ -279,7 +309,7 @@ class ChatGPTAPI:
           self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
           new_tokens = tokens[prev_last_tokens_len:]
           finish_reason = None
-          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
           if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
             new_tokens = new_tokens[:-1]
             if is_finished:
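
One sizing note on the client_max_size bump: base64 encoding inflates binary payloads by a factor of 4/3 (plus padding), so a 100 MB request body leaves room for roughly 75 MB of raw image bytes. A back-of-envelope check:

```python
import math

# base64 emits 4 output characters per 3 input bytes, so ~75 MB of raw
# image data lands at exactly the new 100 MB client_max_size.
raw = 75 * 1024 * 1024
encoded = 4 * math.ceil(raw / 3)
print(encoded / (1024 * 1024))  # 100.0
```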

+ 1 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -14,7 +14,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
     if image_str:
-      image = get_image_from_str(image_str)
+      image = await get_image_from_str(image_str)
       inputs = self.tokenizer(prompt, image, return_tensors="np")
       pixel_values = mx.array(inputs["pixel_values"])
       input_ids = mx.array(inputs["input_ids"])

+ 26 - 8
exo/inference/mlx/sharded_utils.py

@@ -5,13 +5,16 @@ import importlib
 import json
 import logging
 import asyncio
+import aiohttp
 from functools import partial
 from pathlib import Path
 from typing import Optional, Tuple
 import requests
 from PIL import Image
 from io import BytesIO
+import base64
 
+from exo import DEBUG
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
@@ -232,11 +235,26 @@ async def load_shard(
     tokenizer = load_tokenizer(model_path, tokenizer_config)
     return model, tokenizer
 
-def get_image_from_str(image_str: str):
-  if image_str.startswith("http"):
-    response = requests.get(image_str, timeout=10)
-    image = Image.open(BytesIO(response.content)).convert("RGB")
-  else:
-    imgdata = base64.b64decode(image_str)
-    image = Image.open(io.BytesIO(imgdata))
-  return image
+async def get_image_from_str(_image_str: str):
+  image_str = _image_str.strip()
+
+  if image_str.startswith("http"):
+    async with aiohttp.ClientSession() as session:
+      async with session.get(image_str, timeout=10) as response:
+        content = await response.read()
+        return Image.open(BytesIO(content)).convert("RGB")
+  elif image_str.startswith("data:image/"):
+    # Extract the image format and base64 data
+    format_prefix, base64_data = image_str.split(";base64,")
+    image_format = format_prefix.split("/")[1].lower()
+    if DEBUG >= 2: print(f"{image_str=} {image_format=}")
+    imgdata = base64.b64decode(base64_data)
+    img = Image.open(BytesIO(imgdata))
+
+    # Convert to RGB if not already
+    if img.mode != "RGB":
+      img = img.convert("RGB")
+
+    return img
+  else:
+    raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
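
Since get_image_from_str is now a coroutine, callers must await it, as the inference engine change above does. A hypothetical smoke test outside the server:

```python
import asyncio

from exo.inference.mlx.sharded_utils import get_image_from_str

# Illustrative only: exercises the URL branch; a "data:image/...;base64,..."
# string would take the other branch, and anything else raises ValueError.
async def main():
  img = await get_image_from_str("http://images.cocodataset.org/val2017/000000039769.jpg")
  print(img.size, img.mode)  # e.g. (640, 480) RGB

asyncio.run(main())
```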