# sharded_inference_engine.py — MLX fixed- and dynamic-shard inference engines.
from typing import Optional, Tuple

import mlx.core as mx
import numpy as np

from ..inference_engine import InferenceEngine
from ..shard import Shard
from .sharded_model import StatefulShardedModel
from .sharded_utils import load_shard
  8. class MLXFixedShardInferenceEngine(InferenceEngine):
  9. def __init__(self, model_path: str, shard: Shard):
  10. self.shard = shard
  11. model_shard, self.tokenizer = load_shard(model_path, shard)
  12. self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
  13. async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
  14. if shard != self.shard:
  15. raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
  16. output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
  17. return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
  18. async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, str, bool):
  19. if shard != self.shard:
  20. raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
  21. output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
  22. return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
  23. async def reset_shard(self, shard: Shard):
  24. if shard != self.shard:
  25. raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
  26. self.stateful_sharded_model.reset()
  27. class MLXDynamicShardInferenceEngine(InferenceEngine):
  28. def __init__(self):
  29. self.shard = None
  30. async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
  31. await self.ensure_shard(shard)
  32. output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
  33. return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
  34. async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
  35. await self.ensure_shard(shard)
  36. output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
  37. return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
  38. async def reset_shard(self, shard: Shard):
  39. await self.ensure_shard(shard)
  40. self.stateful_sharded_model.reset()
  41. async def ensure_shard(self, shard: Shard):
  42. if self.shard == shard:
  43. return
  44. model_shard, self.tokenizer = load_shard(shard.model_id, shard)
  45. self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
  46. self.shard = shard