test_inference_engine.py
import asyncio

import numpy as np

from inference.inference_engine import InferenceEngine
from inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from inference.shard import Shard


# An inference engine should produce the same output for any number of shards,
# as long as the shards are contiguous and together cover all layers.
async def test_inference_engine(inference_engine: InferenceEngine, model_id: str, input_data: np.ndarray):
  # Full pass: a single shard spanning both layers.
  resp_full, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=1, n_layers=2), input_data=input_data)
  # Split pass: run layer 0 alone, then feed its output into layer 1.
  resp1, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=0, n_layers=2), input_data=input_data)
  resp2, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=1, end_layer=1, n_layers=2), input_data=resp1)
  assert np.array_equal(resp_full, resp2)


asyncio.run(test_inference_engine(
  MLXDynamicShardInferenceEngine(),
  "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
  np.array([1234]),  # prompt token ids, wrapped as an ndarray to match the type hint
))
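
The property this test relies on, that running contiguous shards back to back reproduces a full forward pass, can be checked without downloading a model. The sketch below is illustrative only: ToyShard and FakeShardedEngine are stand-ins invented here, where each "layer" is just a fixed matrix multiply, not the real Shard or engine API.

import asyncio
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyShard:
  # Stand-in for inference.shard.Shard: a contiguous [start_layer, end_layer] slice.
  model_id: str
  start_layer: int
  end_layer: int
  n_layers: int


class FakeShardedEngine:
  # Hypothetical engine in which each layer applies a fixed linear map, so
  # chaining shard [0, 0] into shard [1, 1] must equal running shard [0, 1].
  def __init__(self, n_layers: int = 2, dim: int = 3, seed: int = 0):
    rng = np.random.default_rng(seed)
    self.weights = [rng.standard_normal((dim, dim)) for _ in range(n_layers)]

  async def infer_tensor(self, shard: ToyShard, input_data: np.ndarray):
    out = np.asarray(input_data, dtype=np.float64)
    for layer in range(shard.start_layer, shard.end_layer + 1):
      out = out @ self.weights[layer]
    return out, None


async def check_shard_equivalence():
  engine = FakeShardedEngine()
  x = np.ones((1, 3))
  full, _ = await engine.infer_tensor(ToyShard("toy", 0, 1, 2), x)
  part1, _ = await engine.infer_tensor(ToyShard("toy", 0, 0, 2), x)
  part2, _ = await engine.infer_tensor(ToyShard("toy", 1, 1, 2), part1)
  # np.allclose guards against floating-point noise that exact equality would miss.
  assert np.allclose(full, part2)

asyncio.run(check_shard_equivalence())

The real test uses np.array_equal, which is reasonable when both paths execute the exact same kernels in the same order; np.allclose is the safer default for a toy floating-point pipeline like this one.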