test_dummy_inference_engine.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import pytest
  2. import numpy as np
  3. from exo.inference.dummy_inference_engine import DummyInferenceEngine
  4. from exo.inference.shard import Shard
  5. @pytest.mark.asyncio
  6. async def test_dummy_inference_engine():
  7. # Create a mock shard downloader
  8. class MockShardDownloader:
  9. async def ensure_shard(self, shard):
  10. pass
  11. # Initialize the DummyInferenceEngine
  12. engine = DummyInferenceEngine(MockShardDownloader())
  13. # Create a test shard
  14. shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
  15. # Test infer_prompt
  16. output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")
  17. assert isinstance(output, np.ndarray), "Output should be a numpy array"
  18. assert output.ndim == 2, "Output should be 2-dimensional"
  19. assert isinstance(state, str), "State should be a string"
  20. assert isinstance(is_finished, bool), "is_finished should be a boolean"
  21. # Test infer_tensor
  22. input_tensor = np.array([[1, 2, 3]])
  23. output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)
  24. assert isinstance(output, np.ndarray), "Output should be a numpy array"
  25. assert output.ndim == 2, "Output should be 2-dimensional"
  26. assert isinstance(state, str), "State should be a string"
  27. assert isinstance(is_finished, bool), "is_finished should be a boolean"
  28. print("All tests passed!")
  29. if __name__ == "__main__":
  30. import asyncio
  31. asyncio.run(test_dummy_inference_engine())