Browse Source

fix image api prompt encoding

Alex Cheema 9 months ago
parent
commit
178fb75c84
3 changed files with 85 additions and 41 deletions
  1. 1 1
      README.md
  2. 36 14
      exo/api/chatgpt_api.py
  3. 48 26
      tinychat/examples/tinychat/index.js

+ 1 - 1
README.md

@@ -117,7 +117,7 @@ For developers, exo also starts a ChatGPT-compatible API endpoint on http://loca
 curl http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
-     "model": "llama-3-8b",
+     "model": "llama-3.1-8b",
      "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
      "temperature": 0.7
    }'

+ 36 - 14
exo/api/chatgpt_api.py

@@ -50,16 +50,29 @@ shard_mappings = {
 
 
 class Message:
-  def __init__(self, role: str, content: Union[str, list]):
-    self.role = role
-    self.content = content
+    def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
+        self.role = role
+        self.content = content
+
+    def to_dict(self):
+        return {
+            "role": self.role,
+            "content": self.content
+        }
 
 
 class ChatCompletionRequest:
-  def __init__(self, model: str, messages: List[Message], temperature: float):
-    self.model = model
-    self.messages = messages
-    self.temperature = temperature
+    def __init__(self, model: str, messages: List[Message], temperature: float):
+        self.model = model
+        self.messages = messages
+        self.temperature = temperature
+
+    def to_dict(self):
+        return {
+            "model": self.model,
+            "messages": [message.to_dict() for message in self.messages],
+            "temperature": self.temperature
+        }
 
 
 def resolve_tinygrad_tokenizer(model_id: str):
@@ -75,8 +88,12 @@ async def resolve_tokenizer(model_id: str):
   try:
     if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}")
     processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
-    processor.eos_token_id = processor.tokenizer.eos_token_id
-    processor.encode = processor.tokenizer.encode
+    if not hasattr(processor, 'eos_token_id'):
+      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
+    if not hasattr(processor, 'encode'):
+      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
+    if not hasattr(processor, 'decode'):
+      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
     if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}")
@@ -157,6 +174,10 @@ def remap_messages(messages: List[Message]) -> List[Message]:
     remapped_messages = []
     last_image = None
     for message in messages:
+        if not isinstance(message.content, list):
+           remapped_messages.append(message)
+           continue
+
         remapped_content = []
         for content in message.content:
             if isinstance(content, dict):
@@ -168,16 +189,17 @@ def remap_messages(messages: List[Message]) -> List[Message]:
                 else:
                     remapped_content.append(content)
             else:
-                remapped_content.append({"type": "text", "text": content})
+                remapped_content.append(content)
         remapped_messages.append(Message(role=message.role, content=remapped_content))
 
     if last_image:
         # Replace the last image placeholder with the actual image content
         for message in reversed(remapped_messages):
             for i, content in enumerate(message.content):
-                if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
-                    message.content[i] = last_image
-                    return remapped_messages
+                if isinstance(content, dict):
+                  if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
+                      message.content[i] = last_image
+                      return remapped_messages
 
     return remapped_messages
 
@@ -192,7 +214,7 @@ def build_prompt(tokenizer, _messages: List[Message]):
     for content in message.content:
       # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
       # follows the convention in https://platform.openai.com/docs/guides/vision
-      if content.get("type", None) == "image":
+      if isinstance(content, dict) and content.get("type", None) == "image":
         image_str = content.get("image", None)
         break
 

+ 48 - 26
tinychat/examples/tinychat/index.js

@@ -79,7 +79,7 @@ document.addEventListener("alpine:init", () => {
       this.tokens_per_second = 0;
 
       // prepare messages for API request
-      const apiMessages = this.cstate.messages.map(msg => {
+      let apiMessages = this.cstate.messages.map(msg => {
         if (msg.content.startsWith('![Uploaded Image]')) {
           return {
             role: "user",
@@ -89,36 +89,40 @@ document.addEventListener("alpine:init", () => {
                 image_url: {
                   url: this.imageUrl
                 }
+              },
+              {
+                type: "text",
+                text: value // Use the actual text the user typed
               }
             ]
           };
         } else {
           return {
             role: msg.role,
-            content: [
-              {
-                type: "text",
-                text: msg.content
-              }
-            ]
+            content: msg.content
           };
         }
       });
-
-      // If there's an image URL, add it to all messages
-      if (this.imageUrl) {
-        apiMessages.forEach(msg => {
-          if (!msg.content.some(content => content.type === "image_url")) {
-            msg.content.push({
-              type: "image_url",
-              image_url: {
-                url: this.imageUrl
-              }
-            });
+      const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
+      if (containsImage) {
+        // Map all messages with string content to object with type text
+        apiMessages = apiMessages.map(msg => {
+          if (typeof msg.content === 'string') {
+            return {
+              ...msg,
+              content: [
+                {
+                  type: "text",
+                  text: msg.content
+                }
+              ]
+            };
           }
+          return msg;
         });
       }
 
+
       // start receiving server sent events
       let gottenFirstChunk = false;
       for await (
@@ -146,19 +150,37 @@ document.addEventListener("alpine:init", () => {
         }
       }
 
-      // update the state in histories or add it if it doesn't exist
-      const index = this.histories.findIndex((cstate) => {
-        return cstate.time === this.cstate.time;
+      // Clean the cstate before adding it to histories
+      const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
+      cleanedCstate.messages = cleanedCstate.messages.map(msg => {
+        if (Array.isArray(msg.content)) {
+          return {
+            ...msg,
+            content: msg.content.map(item =>
+              item.type === 'image_url' ? { type: 'image_url', image_url: { url: '[IMAGE_PLACEHOLDER]' } } : item
+            )
+          };
+        }
+        return msg;
       });
-      this.cstate.time = Date.now();
+
+      // Update the state in histories or add it if it doesn't exist
+      const index = this.histories.findIndex((cstate) => cstate.time === cleanedCstate.time);
+      cleanedCstate.time = Date.now();
       if (index !== -1) {
-        // update the time
-        this.histories[index] = this.cstate;
+        // Update the existing entry
+        this.histories[index] = cleanedCstate;
       } else {
-        this.histories.push(this.cstate);
+        // Add a new entry
+        this.histories.push(cleanedCstate);
       }
+      console.log(this.histories)
       // update in local storage
-      localStorage.setItem("histories", JSON.stringify(this.histories));
+      try {
+        localStorage.setItem("histories", JSON.stringify(this.histories));
+      } catch (error) {
+        console.error("Failed to save histories to localStorage:", error);
+      }
 
       this.generating = false;
     },