
chatgpt api integration

Varshith · 9 months ago
Parent commit: acc94b50c7

+ 35 - 8
exo/api/chatgpt_api.py

@@ -3,7 +3,7 @@ import time
 import asyncio
 import json
 from pathlib import Path
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoProcessor
 from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
@@ -42,11 +42,15 @@ shard_mappings = {
   "deepseek-coder-v2-lite": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
   },
+  ### llava
+  "llava-1.5-7b-hf": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
+  },
 }
 
 
 class Message:
-  def __init__(self, role: str, content: str):
+  def __init__(self, role: str, content: Union[str, list]):
     self.role = role
     self.content = content
 
@@ -68,6 +72,18 @@ def resolve_tinygrad_tokenizer(model_id: str):
 
 
 async def resolve_tokenizer(model_id: str):
+  try:
+    if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}")
+    processor = AutoProcessor.from_pretrained(model_id)
+    processor.eos_token_id = processor.tokenizer.eos_token_id
+    processor.encode = processor.tokenizer.encode
+    return processor
+  except Exception as e:
+    if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}")
+    import traceback
+
+    if DEBUG >= 2: print(traceback.format_exc())
+
   try:
     if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
     return AutoTokenizer.from_pretrained(model_id)
@@ -138,7 +154,18 @@ def generate_completion(
 
 
 def build_prompt(tokenizer, messages: List[Message]):
-  return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+  image_str = None
+  for message in messages:
+    if not isinstance(message.content, list):
+      continue
+
+    for content in message.content:
+      if content.get("type", None) == "image":
+        image_str = content.get("image", None)
+        break
+
+  return prompt, image_str
 
 
 def parse_message(data: dict):
@@ -195,7 +222,7 @@ class ChatGPTAPI:
     shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(shard.model_id)
-    return web.json_response({"length": len(build_prompt(tokenizer, messages))})
+    return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
 
   async def handle_post_chat_completions(self, request):
     data = await request.json()
@@ -219,13 +246,13 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(shard.model_id)
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
-    prompt = build_prompt(tokenizer, chat_request.messages)
+    prompt, image_str = build_prompt(tokenizer, chat_request.messages)
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)
 
-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
+    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
     try:
-      await self.node.process_prompt(shard, prompt, request_id=request_id)
+      await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
     except Exception as e:
       if DEBUG >= 2:
         import traceback
@@ -294,7 +321,7 @@ class ChatGPTAPI:
         )
 
         finish_reason = "length"
-        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
         if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
         if tokens[-1] == eos_token_id:
           tokens = tokens[:-1]
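
For reference, a minimal client-side sketch of a request that exercises the new image path. The host, port, and image URL are assumptions (the serving endpoint itself is not shown in this diff); the only hard requirement build_prompt() introduces is a message whose content is a list containing an item with "type": "image" and an "image" value, which may be an http(s) URL or a base64-encoded image per get_image_from_str further down.

# Hypothetical client call against the ChatGPT-compatible API with an image message.
import requests

payload = {
  "model": "llava-1.5-7b-hf",
  "messages": [{
    "role": "user",
    "content": [
      {"type": "text", "text": "What is shown in this image?"},
      {"type": "image", "image": "https://example.com/cat.png"},  # URL or base64 string
    ],
  }],
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=120)
print(resp.json())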

+ 10 - 3
exo/inference/mlx/sharded_inference_engine.py

@@ -2,7 +2,7 @@ import numpy as np
 import mlx.core as mx
 from ..inference_engine import InferenceEngine
 from .sharded_model import StatefulShardedModel
-from .sharded_utils import load_shard
+from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
 from typing import Optional
 
@@ -11,9 +11,16 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self):
     self.shard = None
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
-    output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt))))
+    if image_str:
+      image = get_image_from_str(image_str)
+      inputs = self.tokenizer(prompt, image, return_tensors="np")
+      pixel_values = mx.array(inputs["pixel_values"])
+      input_ids = mx.array(inputs["input_ids"])
+      output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, input_ids, pixel_values))
+    else:
+      output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt))))
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
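
A standalone sketch of what the image branch above does with the processor (the prompt text and image file are illustrative). For llava shards, load_shard() returns an AutoProcessor that wraps both a tokenizer and an image processor, so calling it with (prompt, image) and return_tensors="np" yields the "input_ids" and "pixel_values" arrays that infer_prompt converts to mx.array and passes to step().

# Hedged sketch mirroring the processor call in infer_prompt(); not a drop-in test.
from transformers import AutoProcessor
from PIL import Image

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
image = Image.open("example.png")  # in the engine this comes from get_image_from_str(image_str)
inputs = processor("USER: <image>\nWhat is in this picture? ASSISTANT:", image, return_tensors="np")
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)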

+ 15 - 0
exo/inference/mlx/sharded_utils.py

@@ -8,6 +8,10 @@ import asyncio
 from functools import partial
 from pathlib import Path
 from typing import Optional, Tuple
+import base64
+import requests
+from PIL import Image
+from io import BytesIO
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -222,7 +225,18 @@ async def load_shard(
   # TODO: figure out a generic solution
   if model.model_type == "llava":
     processor = AutoProcessor.from_pretrained(model_path)
+    processor.eos_token_id = processor.tokenizer.eos_token_id
+    processor.encode = processor.tokenizer.encode
     return model, processor
   else:
     tokenizer = load_tokenizer(model_path, tokenizer_config)
     return model, tokenizer
+
+def get_image_from_str(image_str: str):
+  if image_str.startswith("http"):
+    response = requests.get(image_str, timeout=10)
+    image = Image.open(BytesIO(response.content)).convert("RGB")
+  else:
+    imgdata = base64.b64decode(image_str)
+    image = Image.open(BytesIO(imgdata))
+  return image
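
A quick usage sketch for get_image_from_str (file names are illustrative): the helper accepts either an http(s) URL, which is fetched with requests, or a raw base64 string, which is decoded in memory.

# Hypothetical usage of get_image_from_str with both supported input forms.
import base64

img_from_url = get_image_from_str("https://example.com/photo.jpg")

with open("photo.jpg", "rb") as f:
  b64 = base64.b64encode(f.read()).decode("utf-8")
img_from_b64 = get_image_from_str(b64)

print(img_from_url.size, img_from_b64.size)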

+ 5 - 3
exo/inference/mlx/test_sharded_llava.py

@@ -15,9 +15,11 @@ shard_full = Shard("llava", 0, 31, 32)
 shard1 = Shard("llava", 0, 12, 32)
 shard2 = Shard("llava", 13, 31, 32)
 
-full_model_shard, full_processor = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard_full))
-model_shard1, processor1 = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard1))
-model_shard2, processor2 = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard2))
+model_path = "llava-hf/llava-1.5-7b-hf"
+
+full_model_shard, full_processor = asyncio.run(load_shard(model_path, shard=shard_full))
+model_shard1, processor1 = asyncio.run(load_shard(model_path, shard=shard1))
+model_shard2, processor2 = asyncio.run(load_shard(model_path, shard=shard2))
 
 full = StatefulShardedModel(shard_full, full_model_shard)
 m1 = StatefulShardedModel(shard1, model_shard1)

+ 2 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -39,9 +39,10 @@ class GRPCPeerHandle(PeerHandle):
     self.channel = None
     self.stub = None
 
-  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
+      image_str=image_str,
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
         start_layer=shard.start_layer,

+ 3 - 2
exo/networking/grpc/grpc_server.py

@@ -45,9 +45,10 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       n_layers=request.shard.n_layers,
     )
     prompt = request.prompt
+    image_str = request.image_str
     request_id = request.request_id
-    result = await self.node.process_prompt(shard, prompt, request_id)
-    if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
+    result = await self.node.process_prompt(shard, prompt, image_str, request_id)
+    if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 

+ 3 - 2
exo/networking/grpc/node_service.proto

@@ -21,8 +21,9 @@ message Shard {
 message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
-  optional string request_id = 3;
-  optional string inference_state = 4;
+  optional string image_str = 3;
+  optional string request_id = 4;
+  optional string inference_state = 5;
 }
 
 message TensorRequest {
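
With the renumbered fields, a peer can attach an image to a prompt over gRPC. Below is a hedged sketch of building the extended PromptRequest by hand; the channel address and field values are illustrative, and the generated NodeServiceStub class name is the conventional one for this service rather than something shown in this diff.

# Hypothetical direct call to SendPrompt with an image_str attached.
import grpc
from exo.networking.grpc import node_service_pb2, node_service_pb2_grpc

async def send_image_prompt(address: str):
  async with grpc.aio.insecure_channel(address) as channel:
    stub = node_service_pb2_grpc.NodeServiceStub(channel)
    request = node_service_pb2.PromptRequest(
      prompt="USER: <image>\nDescribe the image. ASSISTANT:",
      image_str="https://example.com/cat.png",
      shard=node_service_pb2.Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
      request_id="demo-request",
    )
    return await stub.SendPrompt(request)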

+ 0 - 0
exo/networking/grpc/node_service_pb2.py
(File diff not shown because this file is too large.)


+ 1 - 1
exo/networking/peer_handle.py

@@ -28,7 +28,7 @@ class PeerHandle(ABC):
     pass
 
   @abstractmethod
-  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     pass
 
   @abstractmethod

+ 1 - 1
exo/orchestration/node.py

@@ -16,7 +16,7 @@ class Node(ABC):
     pass
 
   @abstractmethod
-  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     pass
 
   @abstractmethod

+ 12 - 9
exo/orchestration/standard_node.py

@@ -69,7 +69,7 @@ class StandardNode(Node):
     await self.discovery.stop()
     await self.server.stop()
 
-  async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
       self.broadcast_opaque_status(
@@ -82,6 +82,7 @@ class StandardNode(Node):
             "base_shard": base_shard.to_dict(),
             "shard": shard.to_dict(),
             "prompt": prompt,
+            "image_str": image_str,
             "inference_state": inference_state,
             "request_id": request_id,
           }
@@ -89,7 +90,7 @@ class StandardNode(Node):
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
+    resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -103,6 +104,7 @@ class StandardNode(Node):
             "base_shard": base_shard.to_dict(),
             "shard": shard.to_dict(),
             "prompt": prompt,
+            "image_str": image_str,
             "inference_state": inference_state,
             "request_id": request_id,
             "elapsed_time_ns": elapsed_time_ns,
@@ -113,20 +115,20 @@ class StandardNode(Node):
     )
     return resp
 
-  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     if request_id not in self.buffered_token_output:
       self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
 
-    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}")
     if shard.start_layer != 0:
-      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      await self.forward_to_next_shard(shard, prompt, request_id)
+      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}")
+      await self.forward_to_next_shard(shard, prompt, request_id, image_str)
       return
 
-    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state)
+    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state)
     is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -234,6 +236,7 @@ class StandardNode(Node):
     base_shard: Shard,
     tensor_or_prompt: Union[np.ndarray, str],
     request_id: str,
+    image_str: Optional[str] = None,
     inference_state: Optional[str] = None,
   ) -> None:
     if not self.partitioning_strategy:
@@ -255,7 +258,7 @@ class StandardNode(Node):
         if isinstance(tensor_or_prompt, np.ndarray):
           await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
         else:
-          await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+          await self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state)
         return
 
       target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
@@ -267,7 +270,7 @@ class StandardNode(Node):
       if isinstance(tensor_or_prompt, np.ndarray):
         await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
       else:
-        await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+        await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state)
 
   def get_current_shard(self, base_shard: Shard) -> Shard:
     partitions = self.partitioning_strategy.partition(self.topology)

+ 1 - 1
setup.py

@@ -22,7 +22,7 @@ install_requires = [
     "tiktoken==0.7.0",
     "tokenizers==0.19.1",
     "tqdm==4.66.4",
-    "transformers==4.41.2",
+    "transformers==4.43.3",
     "uuid==1.30",
     "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@a9f5a764dc640a5e5cbaaeeee21df7c8ca37da38",
 ]

Some files were not shown because too many files have changed in this diff.