# sharded_inference_engine.py
  1. import mlx.nn as nn
  2. import numpy as np
  3. import mlx.core as mx
  4. from ..inference_engine import InferenceEngine
  5. from .sharded_model import StatefulShardedModel
  6. from .sharded_utils import load_shard
  7. from ..shard import Shard
  8. class MLXFixedShardInferenceEngine(InferenceEngine):
  9. def __init__(self, model_path: str, shard: Shard):
  10. print("initializing fixed shard inference", shard)
  11. self.shard = shard
  12. model_shard, self.tokenizer = load_shard(model_path, shard)
  13. self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
  14. async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
  15. if shard != self.shard:
  16. raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
  17. output_data = self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt)))
  18. return np.array(output_data)
  19. async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
  20. if shard != self.shard:
  21. raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
  22. print("infer_shard", shard, input_data)
  23. output_data = self.stateful_sharded_model.step(mx.array(input_data))
  24. return np.array(output_data)
  25. async def reset_shard(self, shard: Shard):
  26. if shard != self.shard:
  27. raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
  28. print(f"Resetting shard: {shard}")
  29. self.stateful_sharded_model.reset()