Browse Source

Merge pull request #329 from varshith15/llava_broken_fix

fix: tokenize in llava pipeline
Alex Cheema 9 months ago
parent
commit
9a93dcc414
1 changed file with 3 additions and 1 deletion
  1. 3 1
      exo/inference/mlx/sharded_inference_engine.py

+ 3 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -8,6 +8,7 @@ from typing import Optional
 from exo.download.shard_download import ShardDownloader
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -20,7 +21,8 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     loop = asyncio.get_running_loop()
     if image_str:
       image = await get_image_from_str(image_str)
-      inputs = await loop.run_in_executor(self.executor, self.tokenizer, prompt, image, return_tensors="np")
+      tokenize = partial(self.tokenizer, prompt, image, return_tensors="np")
+      inputs = await loop.run_in_executor(self.executor, tokenize)
       pixel_values = mx.array(inputs["pixel_values"])
       input_ids = mx.array(inputs["input_ids"])
       output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))