sharded_inference_engine.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import numpy as np
  2. import mlx.core as mx
  3. from ..inference_engine import InferenceEngine
  4. from .sharded_model import StatefulShardedModel
  5. from .sharded_utils import load_shard, get_image_from_str
  6. from ..shard import Shard
  7. from typing import Optional
  8. from exo.download.shard_download import ShardDownloader
  9. import asyncio
  10. from concurrent.futures import ThreadPoolExecutor
  11. from functools import partial
  12. class MLXDynamicShardInferenceEngine(InferenceEngine):
  13. def __init__(self, shard_downloader: ShardDownloader):
  14. self.shard = None
  15. self.shard_downloader = shard_downloader
  16. self.executor = ThreadPoolExecutor(max_workers=1)
  17. async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
  18. await self.ensure_shard(shard)
  19. loop = asyncio.get_running_loop()
  20. if image_str:
  21. image = await get_image_from_str(image_str)
  22. tokenize = partial(self.tokenizer, prompt, image, return_tensors="np")
  23. inputs = await loop.run_in_executor(self.executor, tokenize)
  24. pixel_values = mx.array(inputs["pixel_values"])
  25. input_ids = mx.array(inputs["input_ids"])
  26. output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
  27. else:
  28. input_ids = mx.array(await loop.run_in_executor(self.executor, self.tokenizer.encode, prompt))
  29. output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
  30. return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
  31. async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
  32. await self.ensure_shard(shard)
  33. output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data)))
  34. return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
  35. async def ensure_shard(self, shard: Shard):
  36. if self.shard == shard:
  37. return
  38. model_path = await self.shard_downloader.ensure_shard(shard)
  39. if self.shard != shard:
  40. loop = asyncio.get_running_loop()
  41. def load_shard_wrapper():
  42. return asyncio.run(load_shard(model_path, shard))
  43. model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
  44. self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
  45. self.shard = shard