# sharded_inference_engine.py (1.4 KB)
  1. import numpy as np
  2. import mlx.core as mx
  3. from ..inference_engine import InferenceEngine
  4. from .sharded_model import StatefulShardedModel
  5. from .sharded_utils import load_shard
  6. from ..shard import Shard
  class MLXFixedShardInferenceEngine(InferenceEngine):
      """Inference engine bound to one fixed shard that is loaded at construction.

      Every public coroutine takes a ``shard`` argument and raises
      ``ValueError`` when it differs from the shard this engine was built with.
      """

      def __init__(self, model_path: str, shard: Shard):
          """Load ``shard`` from ``model_path`` and wrap it in a stateful model.

          Sets ``self.shard``, ``self.tokenizer`` and
          ``self.stateful_sharded_model``.
          """
          print("initializing fixed shard inference", shard)
          self.shard = shard
          # load_shard yields a (model, tokenizer) pair for the requested shard.
          loaded_model, tokenizer = load_shard(model_path, shard)
          self.tokenizer = tokenizer
          self.stateful_sharded_model = StatefulShardedModel(shard, loaded_model)

      def _require_own_shard(self, shard: Shard) -> None:
          """Raise ``ValueError`` unless ``shard`` equals the engine's shard."""
          if shard != self.shard:
              raise ValueError(f"Shard mismatch: {shard} != {self.shard}")

      async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
          """Encode ``prompt`` with the tokenizer and run one model step on it.

          Returns the step output converted to a NumPy array.
          Raises ``ValueError`` on a shard mismatch.
          """
          self._require_own_shard(shard)
          token_ids = self.tokenizer.encode(prompt)
          step_output = self.stateful_sharded_model.step(mx.array(token_ids))
          return np.array(step_output)

      async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
          """Run one model step on ``input_data`` and return it as a NumPy array.

          Raises ``ValueError`` on a shard mismatch.
          """
          self._require_own_shard(shard)
          print("infer_shard", shard, input_data)
          model_input = mx.array(input_data)
          step_output = self.stateful_sharded_model.step(model_input)
          return np.array(step_output)

      async def reset_shard(self, shard: Shard):
          """Reset the wrapped stateful model.

          Raises ``ValueError`` on a shard mismatch.
          """
          self._require_own_shard(shard)
          print(f"Resetting shard: {shard}")
          self.stateful_sharded_model.reset()