# test_model_helpers.py - unit tests for exo.models.get_supported_models

import unittest
from typing import NamedTuple

from exo.models import get_supported_models, model_cards
from exo.inference.inference_engine import inference_engine_classes
  5. class TestCase(NamedTuple):
  6. name: str
  7. engine_lists: list # Will contain short names, will be mapped to class names
  8. expected_models_contains: list
  9. min_count: int | None
  10. exact_count: int | None
  11. max_count: int | None
  12. # Helper function to map short names to class names
  13. def expand_engine_lists(engine_lists):
  14. def map_engine(engine):
  15. return inference_engine_classes.get(engine, engine) # Return original name if not found
  16. return [[map_engine(engine) for engine in sublist]
  17. for sublist in engine_lists]
  18. test_cases = [
  19. TestCase(
  20. name="single_mlx_engine",
  21. engine_lists=[["mlx"]],
  22. expected_models_contains=["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"],
  23. min_count=10,
  24. exact_count=None,
  25. max_count=None
  26. ),
  27. TestCase(
  28. name="single_tinygrad_engine",
  29. engine_lists=[["tinygrad"]],
  30. expected_models_contains=["llama-3.2-1b", "llama-3.2-3b"],
  31. min_count=5,
  32. exact_count=None,
  33. max_count=10
  34. ),
  35. TestCase(
  36. name="multiple_engines_or",
  37. engine_lists=[["mlx", "tinygrad"], ["mlx"]],
  38. expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
  39. min_count=10,
  40. exact_count=None,
  41. max_count=None
  42. ),
  43. TestCase(
  44. name="multiple_engines_all",
  45. engine_lists=[["mlx", "tinygrad"], ["mlx", "tinygrad"]],
  46. expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
  47. min_count=10,
  48. exact_count=None,
  49. max_count=None
  50. ),
  51. TestCase(
  52. name="distinct_engine_lists",
  53. engine_lists=[["mlx"], ["tinygrad"]],
  54. expected_models_contains=["llama-3.2-1b"],
  55. min_count=5,
  56. exact_count=None,
  57. max_count=10
  58. ),
  59. TestCase(
  60. name="no_engines",
  61. engine_lists=[],
  62. expected_models_contains=None,
  63. min_count=None,
  64. exact_count=len(model_cards),
  65. max_count=None
  66. ),
  67. TestCase(
  68. name="nonexistent_engine",
  69. engine_lists=[["NonexistentEngine"]],
  70. expected_models_contains=[],
  71. min_count=None,
  72. exact_count=0,
  73. max_count=None
  74. ),
  75. TestCase(
  76. name="dummy_engine",
  77. engine_lists=[["dummy"]],
  78. expected_models_contains=["dummy"],
  79. min_count=None,
  80. exact_count=1,
  81. max_count=None
  82. ),
  83. ]
  84. class TestModelHelpers(unittest.TestCase):
  85. def test_get_supported_models(self):
  86. for case in test_cases:
  87. with self.subTest(f"{case.name}_short_names"):
  88. result = get_supported_models(case.engine_lists)
  89. self._verify_results(case, result)
  90. with self.subTest(f"{case.name}_class_names"):
  91. class_name_lists = expand_engine_lists(case.engine_lists)
  92. result = get_supported_models(class_name_lists)
  93. self._verify_results(case, result)
  94. def _verify_results(self, case, result):
  95. if case.expected_models_contains:
  96. for model in case.expected_models_contains:
  97. self.assertIn(model, result)
  98. if case.min_count:
  99. self.assertGreater(len(result), case.min_count)
  100. if case.exact_count is not None:
  101. self.assertEqual(len(result), case.exact_count)
  102. # Special case for distinct lists test
  103. if case.name == "distinct_engine_lists":
  104. self.assertLess(len(result), 10)
  105. self.assertNotIn("mistral-nemo", result)
  106. if case.max_count:
  107. self.assertLess(len(result), case.max_count)
  108. if __name__ == '__main__':
  109. unittest.main()