|
@@ -8,6 +8,7 @@ from typing import Optional
|
|
|
from exo.download.shard_download import ShardDownloader
|
|
|
import asyncio
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
+from functools import partial
|
|
|
|
|
|
class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
def __init__(self, shard_downloader: ShardDownloader):
|
|
@@ -20,7 +21,8 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
loop = asyncio.get_running_loop()
|
|
|
if image_str:
|
|
|
image = await get_image_from_str(image_str)
|
|
|
- inputs = await loop.run_in_executor(self.executor, self.tokenizer, prompt, image, return_tensors="np")
|
|
|
+ tokenize = partial(self.tokenizer, prompt, image, return_tensors="np")
|
|
|
+ inputs = await loop.run_in_executor(self.executor, tokenize)
|
|
|
pixel_values = mx.array(inputs["pixel_values"])
|
|
|
input_ids = mx.array(inputs["input_ids"])
|
|
|
output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
|