sharded_inference_engine.py

import numpy as np
import mlx.core as mx
from ..inference_engine import InferenceEngine
from .sharded_model import StatefulShardedModel
from .sharded_utils import load_shard, get_image_from_str
from ..shard import Shard
from typing import Optional, Tuple


class MLXDynamicShardInferenceEngine(InferenceEngine):
  def __init__(self):
    self.shard = None

  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    await self.ensure_shard(shard)
    if image_str:
      # Multimodal path: decode the image and tokenize it together with the prompt.
      image = await get_image_from_str(image_str)
      inputs = self.tokenizer(prompt, image, return_tensors="np")
      pixel_values = mx.array(inputs["pixel_values"])
      input_ids = mx.array(inputs["input_ids"])
      output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, input_ids, pixel_values))
    else:
      # Text-only path: encode the prompt and step it through this node's layers.
      output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt))))
    # Generation is finished once the model emits a single EOS token.
    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id

  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    await self.ensure_shard(shard)
    output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(input_data)))
    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id

  async def ensure_shard(self, shard: Shard):
    # Skip reloading if the requested shard is already resident on this node.
    if self.shard == shard:
      return
    model_shard, self.tokenizer = await load_shard(shard.model_id, shard)
    self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
    self.shard = shard
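
# A minimal usage sketch of driving this engine from a separate module. It
# assumes Shard exposes the model_id/start_layer/end_layer/n_layers fields of
# exo's Shard dataclass; the model id and layer counts below are illustrative
# placeholders, not values taken from this file.
#
#   import asyncio
#
#   async def main():
#     engine = MLXDynamicShardInferenceEngine()
#     shard = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit",
#                   start_layer=0, end_layer=31, n_layers=32)
#     # First step: tokenize the prompt and run it through this shard's layers.
#     output, _, is_finished = await engine.infer_prompt("request-1", shard, "Hello")
#     # Subsequent steps: feed each output back in until EOS is emitted.
#     while not is_finished:
#       output, _, is_finished = await engine.infer_tensor("request-1", shard, output)
#
#   asyncio.run(main())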