test_inference_engine.py

from inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from inference.inference_engine import InferenceEngine
from inference.shard import Shard
from inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
import asyncio
import numpy as np
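
# Shard(model_id, start_layer, end_layer, n_layers) describes a contiguous
# slice of a model's layers; judging from the calls below, end_layer appears
# to be inclusive. For a 2-layer model, the whole model and a two-way split
# would look like this (the "m" model_id is just a placeholder):
#   full   = Shard(model_id="m", start_layer=0, end_layer=1, n_layers=2)
#   first  = Shard(model_id="m", start_layer=0, end_layer=0, n_layers=2)
#   second = Shard(model_id="m", start_layer=1, end_layer=1, n_layers=2)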

# An inference engine should produce the same output for any number of shards,
# as long as the shards are contiguous and together cover all of the model's layers.
async def test_inference_engine(inference_engine: InferenceEngine, model_id: str, input_data: np.ndarray):
  # inference_engine.reset_shard(Shard("", 0, 0, 0))
  resp_full, _ = await inference_engine.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=1, n_layers=2), prompt="In one word, what is the capital of USA? ")
  print("resp_full", resp_full)
  print("decoded", inference_engine.tokenizer.decode(resp_full))

  # Disabled check: running the same model one shard at a time should
  # reproduce the single-shard result exactly.
  # inference_engine.reset_shard(Shard("", 0, 0, 0))
  # resp1, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=0, n_layers=2), input_data=input_data)
  # resp2, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=1, end_layer=1, n_layers=2), input_data=resp1)
  # assert np.array_equal(resp_full, resp2)

# asyncio.run(test_inference_engine(
#   MLXDynamicShardInferenceEngine(),
#   "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
#   [1234]
# ))
asyncio.run(test_inference_engine(
  TinygradDynamicShardInferenceEngine(),
  "/Users/alex/Library/Caches/tinygrad/downloads/llama3-8b-sfr",
  [1234]
))
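
# A minimal stand-in engine for exercising this harness without downloading
# model weights. EchoInferenceEngine is hypothetical (not part of the
# inference package); it only mirrors the call signatures used above:
# infer_prompt/infer_tensor return an (np.ndarray, state) pair, and the
# engine exposes a tokenizer with a .decode method.
class _EchoTokenizer:
  def decode(self, tokens):
    # Naive stand-in decode: render token ids as a space-separated string.
    return " ".join(str(int(t)) for t in np.asarray(tokens).flatten())

class EchoInferenceEngine:
  tokenizer = _EchoTokenizer()

  async def infer_prompt(self, shard: Shard, prompt: str):
    # Pretend the "model" maps the prompt to its character codes.
    return np.array([ord(c) for c in prompt]), None

  async def infer_tensor(self, shard: Shard, input_data):
    # Identity on every shard, so any chain of contiguous shards is a no-op
    # and the shard-equivalence property trivially holds.
    return np.asarray(input_data), None

# asyncio.run(test_inference_engine(EchoInferenceEngine(), "echo-model", [1234]))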