sharded_inference_engine.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
from typing import Tuple

import mlx.core as mx
import numpy as np

from ..inference_engine import InferenceEngine
from ..shard import Shard
from .sharded_model import StatefulShardedModel
from .sharded_utils import load_shard
  7. class MLXFixedShardInferenceEngine(InferenceEngine):
  8. def __init__(self, model_path: str, shard: Shard):
  9. print("initializing fixed shard inference", shard)
  10. self.shard = shard
  11. model_shard, self.tokenizer = load_shard(model_path, shard)
  12. self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
  13. async def infer_prompt(self, shard: Shard, prompt: str) -> (np.ndarray, bool):
  14. if shard != self.shard:
  15. raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
  16. output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
  17. print(f"output_data size: {output_data.size}, output_data: {output_data}")
  18. return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
  19. async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, bool):
  20. if shard != self.shard:
  21. raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
  22. output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
  23. return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
  24. async def reset_shard(self, shard: Shard):
  25. if shard != self.shard:
  26. raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
  27. print(f"Resetting shard: {shard}")
  28. self.stateful_sharded_model.reset()
  29. class MLXDynamicShardInferenceEngine(InferenceEngine):
  30. def __init__(self):
  31. self.shard = None
  32. async def infer_prompt(self, shard: Shard, prompt: str) -> (np.ndarray, bool):
  33. await self.ensure_shard(shard)
  34. output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
  35. return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
  36. async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, bool):
  37. await self.ensure_shard(shard)
  38. output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
  39. return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
  40. async def reset_shard(self, shard: Shard):
  41. await self.ensure_shard(shard)
  42. print(f"Resetting shard: {shard}")
  43. self.stateful_sharded_model.reset()
  44. async def ensure_shard(self, shard: Shard):
  45. if self.shard == shard:
  46. return
  47. model_shard, self.tokenizer = load_shard(shard.model_id, shard)
  48. self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
  49. self.shard = shard