test_inference_engine.py

from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard
from exo.helpers import DEBUG
import os
import asyncio
import numpy as np
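
# A minimal sketch (not part of the original test) of what "contiguous" means for the
# shards used below: two partial shards must abut exactly, with no gap and no overlap.
# Shard is assumed to be a plain dataclass here, so throwaway instances have no side effects.
_head = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=7, n_layers=16)
_tail = Shard(model_id="llama-3.2-1b", start_layer=8, end_layer=15, n_layers=16)
assert _tail.start_layer == _head.end_layer + 1  # contiguous: next shard starts where the last ended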

# An inference engine should work the same for any number of Shards, as long as the Shards
# are continuous; a generalized sweep over split points is sketched after this function.
async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
  prompt = "In a single word only, what is the last name of the current president of the USA?"

  # Reference run: prefill the prompt through one shard covering the full model, sample a
  # token, then feed it back for a single decode step.
  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
  token_full = await inference_engine_1.sample(resp_full)
  token_full = token_full.reshape(1, -1)
  next_resp_full = await inference_engine_1.infer_tensor(
    "A",
    shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
    input_data=token_full,
  )
  # Split run: partition the model into two contiguous shards at the midpoint and chain
  # them across the two engines, repeating the same prefill-then-decode sequence.
  pp = n_layers // 2
  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
  resp2 = await inference_engine_2.infer_tensor(
    "B",
    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
    input_data=resp1,
  )
  tokens2 = await inference_engine_1.sample(resp2)
  tokens2 = tokens2.reshape(1, -1)
  resp3 = await inference_engine_1.infer_tensor(
    "B",
    shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
    input_data=tokens2,
  )
  resp4 = await inference_engine_2.infer_tensor(
    "B",
    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
    input_data=resp3,
  )
  # The sharded pipeline must match the full-model run exactly, for both prefill and decode.
  assert np.array_equal(resp_full, resp2)
  assert np.array_equal(next_resp_full, resp4)
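
# The comment above claims equality for any contiguous partition, not only the midpoint
# split exercised by test_inference_engine. A hedged sketch (not in the original file) that
# sweeps every two-way split point and checks each prefill against the full-model run:
async def test_all_split_points(engine_1: InferenceEngine, engine_2: InferenceEngine, model_id: str, n_layers: int):
  prompt = "In a single word only, what is the last name of the current president of the USA?"
  full = Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers)
  resp_full = await engine_1.infer_prompt("C", shard=full, prompt=prompt)
  for pp in range(n_layers - 1):
    # Head covers layers 0..pp, tail covers pp+1..n_layers-1; a fresh request id per split.
    head = Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers)
    tail = Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers)
    resp_head = await engine_1.infer_prompt(f"C-{pp}", shard=head, prompt=prompt)
    resp_tail = await engine_2.infer_tensor(f"C-{pp}", shard=tail, input_data=resp_head)
    assert np.array_equal(resp_full, resp_tail)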

# Run the check with the MLX backend.
asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))
# Optionally repeat the check with the tinygrad backend on a larger model.
if os.getenv("RUN_TINYGRAD", default="0") == "1":
  import tinygrad
  from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
  tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
  asyncio.run(
    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
  )
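
# Example invocations (the module path is illustrative, assuming this file sits at
# exo/inference/test_inference_engine.py in the repo):
#   python3 -m exo.inference.test_inference_engine
#   RUN_TINYGRAD=1 TINYGRAD_DEBUG=2 python3 -m exo.inference.test_inference_engine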