
increase max request size to send raw images, make image download from url async, use chatgpt-compatible convention for images

Alex Cheema 9 months ago
commit 0d45a855fb

+ 4 - 2
README.md

@@ -137,8 +137,10 @@ curl http://localhost:8000/v1/chat/completions \
             "text": "What are these?"
             "text": "What are these?"
           },
           },
           {
           {
-            "type": "image",
-            "image": "http://images.cocodataset.org/val2017/000000039769.jpg"
+            "type": "image_url",
+            "image_url": {
+              "url": "http://images.cocodataset.org/val2017/000000039769.jpg"
+            }
          }
        ]
      }
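
For reference, the same request can be issued from Python against a local exo instance on port 8000, as in the README's curl example. This is a minimal sketch; the model name is illustrative:

import requests

response = requests.post(
  "http://localhost:8000/v1/chat/completions",
  json={
    "model": "llava-1.5-7b-hf",  # illustrative model name
    "messages": [{
      "role": "user",
      "content": [
        {"type": "text", "text": "What are these?"},
        # New ChatGPT-compatible convention: "image_url" with a nested "url".
        {"type": "image_url", "image_url": {"url": "http://images.cocodataset.org/val2017/000000039769.jpg"}},
      ],
    }],
  },
  timeout=60,
)
print(response.json())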

+ 35 - 5
exo/api/chatgpt_api.py

@@ -153,14 +153,45 @@ def generate_completion(
  return completion


-def build_prompt(tokenizer, messages: List[Message]):
-  prompt =  tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+def remap_messages(messages: List[Message]) -> List[Message]:
+  remapped_messages = []
+  last_image = None
+  for message in messages:
+    # Plain-string content has no image parts; pass it through untouched.
+    if not isinstance(message.content, list):
+      remapped_messages.append(message)
+      continue
+
+    remapped_content = []
+    for content in message.content:
+      if isinstance(content, dict):
+        if content.get("type") in ["image_url", "image"]:
+          image_url = content.get("image_url", {}).get("url") or content.get("image")
+          if image_url:
+            last_image = {"type": "image", "image": image_url}
+            remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
+        else:
+          remapped_content.append(content)
+      else:
+        remapped_content.append({"type": "text", "text": content})
+    remapped_messages.append(Message(role=message.role, content=remapped_content))
+
+  if last_image:
+    # Replace the last image placeholder with the actual image content
+    for message in reversed(remapped_messages):
+      if not isinstance(message.content, list):
+        continue
+      for i, content in enumerate(message.content):
+        if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
+          message.content[i] = last_image
+          return remapped_messages
+
+  return remapped_messages
+
+def build_prompt(tokenizer, _messages: List[Message]):
+  messages = remap_messages(_messages)
+  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  image_str = None
  for message in messages:
    if not isinstance(message.content, list):
      continue

    for content in message.content:
+      # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
+      # follows the convention in https://platform.openai.com/docs/guides/vision
      if content.get("type", None) == "image":
        image_str = content.get("image", None)
        break
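
To make the remapping above concrete: every image part in the conversation is collapsed to a text placeholder, and only the most recent image is restored as a real image entry, since (per the comment above) only one image at a time is supported. A sketch of the transformation, with plain dicts standing in for Message objects and hypothetical example URLs:

before = [
  {"role": "user", "content": [
    {"type": "text", "text": "What is this?"},
    {"type": "image_url", "image_url": {"url": "http://example.com/a.jpg"}},  # hypothetical URL
  ]},
  {"role": "user", "content": [
    {"type": "text", "text": "And this one?"},
    {"type": "image_url", "image_url": {"url": "http://example.com/b.jpg"}},  # hypothetical URL
  ]},
]

# After remap_messages, the older image becomes a placeholder and only the
# newest image survives, renamed to the internal {"type": "image"} form:
#   message 0: [{"type": "text", ...}, {"type": "text", "text": "[An image was uploaded but is not displayed here]"}]
#   message 1: [{"type": "text", ...}, {"type": "image", "image": "http://example.com/b.jpg"}]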
@@ -187,7 +218,7 @@ class ChatGPTAPI:
    self.node = node
    self.inference_engine_classname = inference_engine_classname
    self.response_timeout_secs = response_timeout_secs
-    self.app = web.Application()
+    self.app = web.Application(client_max_size=100 * 1024 * 1024)  # 100MB to support image upload
    self.prev_token_lens: Dict[str, int] = {}
    self.stream_tasks: Dict[str, asyncio.Task] = {}
    cors = aiohttp_cors.setup(self.app)
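
The 100MB cap is generous on purpose: images arrive base64-encoded, which inflates the raw bytes by a factor of 4/3, so the usable raw image size is roughly three quarters of the request limit. A quick back-of-the-envelope check:

# base64 emits 4 ASCII characters per 3 raw bytes, so a 100MB request body
# leaves room for roughly 75MB of raw image data before encoding overhead.
limit_bytes = 100 * 1024 * 1024
max_raw_bytes = limit_bytes * 3 // 4
print(max_raw_bytes / (1024 * 1024))  # -> 75.0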
@@ -214,7 +245,6 @@ class ChatGPTAPI:
    return middleware

  async def handle_root(self, request):
-    print(f"Handling root request from {request.remote}")
    return web.FileResponse(self.static_dir / "index.html")

  async def handle_post_chat_token_encode(self, request):
@@ -279,7 +309,7 @@ class ChatGPTAPI:
          self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
          new_tokens = tokens[prev_last_tokens_len:]
          finish_reason = None
-          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
          if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
            new_tokens = new_tokens[:-1]
            if is_finished:

+ 1 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -14,7 +14,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
    await self.ensure_shard(shard)
    if image_str:
-      image = get_image_from_str(image_str)
+      image = await get_image_from_str(image_str)
      inputs = self.tokenizer(prompt, image, return_tensors="np")
      pixel_values = mx.array(inputs["pixel_values"])
      input_ids = mx.array(inputs["input_ids"])
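
The added await here is the substance of the change: the previous synchronous requests.get blocked the event loop for the whole download, stalling every other request the API server was handling. A minimal sketch of the non-blocking pattern this commit switches to, using only aiohttp and asyncio (the helper name is illustrative):

import asyncio
import aiohttp

async def fetch_image_bytes(url: str) -> bytes:
  # The event loop stays free to serve other requests while this download runs.
  async with aiohttp.ClientSession() as session:
    async with session.get(url, timeout=10) as response:
      return await response.read()

# Example (URL taken from the README):
# asyncio.run(fetch_image_bytes("http://images.cocodataset.org/val2017/000000039769.jpg"))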

+ 26 - 8
exo/inference/mlx/sharded_utils.py

@@ -5,13 +5,16 @@ import importlib
import json
import logging
import asyncio
+import aiohttp
from functools import partial
from pathlib import Path
from typing import Optional, Tuple
import requests
from PIL import Image
from io import BytesIO
+import base64
+from exo import DEBUG
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
@@ -232,11 +235,26 @@ async def load_shard(
    tokenizer = load_tokenizer(model_path, tokenizer_config)
    return model, tokenizer

-def get_image_from_str(image_str: str):
-  if image_str.startswith("http"):
-    response = requests.get(image_str, timeout=10)
-    image = Image.open(BytesIO(response.content)).convert("RGB")
-  else:
-    imgdata = base64.b64decode(image_str)
-    image = Image.open(io.BytesIO(imgdata))
-  return image
+async def get_image_from_str(_image_str: str):
+  image_str = _image_str.strip()
+
+  if image_str.startswith("http"):
+    async with aiohttp.ClientSession() as session:
+      async with session.get(image_str, timeout=10) as response:
+        content = await response.read()
+        return Image.open(BytesIO(content)).convert("RGB")
+  elif image_str.startswith("data:image/"):
+    # Extract the image format and base64 data
+    format_prefix, base64_data = image_str.split(";base64,")
+    image_format = format_prefix.split("/")[1].lower()
+    if DEBUG >= 2: print(f"{image_str=} {image_format=}")
+    imgdata = base64.b64decode(base64_data)
+    img = Image.open(BytesIO(imgdata))
+
+    # Convert to RGB if not already
+    if img.mode != "RGB":
+      img = img.convert("RGB")
+
+    return img
+  else:
+    raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
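
Finally, a usage sketch for the reworked helper, covering both accepted input forms; the local file path is hypothetical:

import asyncio
import base64

# 1. HTTP(S) URL (taken from the README example):
image = asyncio.run(get_image_from_str("http://images.cocodataset.org/val2017/000000039769.jpg"))

# 2. Base64 data URL, e.g. built from a local file (path is hypothetical):
with open("cat.png", "rb") as f:
  encoded = base64.b64encode(f.read()).decode("utf-8")
image = asyncio.run(get_image_from_str(f"data:image/png;base64,{encoded}"))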