test_inference_engine.py

from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
import asyncio
import numpy as np

# An inference engine should produce the same output for any number of Shards, as long as the Shards are contiguous.
async def test_inference_engine(inference_engine: InferenceEngine, model_id: str):
  prompt = "In a single word only, what is the capital of Japan? "
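
  # Baseline: run the prompt through a single shard covering all 32 layers.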
  resp_full, inference_state_full, _ = await inference_engine.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
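
  # Split run: the same prompt through two contiguous shards (layers 0-10, then 11-31).
  # The first shard's output and inference state are fed into the second shard via infer_tensor.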
  await inference_engine.reset_shard(shard=Shard(model_id=model_id, start_layer=0, end_layer=10, n_layers=32))
  resp1, inference_state, _ = await inference_engine.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=10, n_layers=32), prompt=prompt)
  await inference_engine.reset_shard(shard=Shard(model_id=model_id, start_layer=11, end_layer=31, n_layers=32))
  resp2, _, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=11, end_layer=31, n_layers=32), input_data=resp1, inference_state=inference_state)
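
  # The sharded pipeline must produce exactly the same output as the full run.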
  assert np.array_equal(resp_full, resp2)
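
# Exercise both engines: MLX with a Hugging Face model id, tinygrad with a local model path.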
asyncio.run(test_inference_engine(
  MLXDynamicShardInferenceEngine(),
  "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
))

asyncio.run(test_inference_engine(
  TinygradDynamicShardInferenceEngine(),
  "/Users/alex/Library/Caches/tinygrad/downloads/llama3-8b-sfr",
))