|
@@ -0,0 +1,121 @@
|
|
|
+import unittest
|
|
|
+from exo.models import get_supported_models, model_cards
|
|
|
+from exo.inference.inference_engine import inference_engine_classes
|
|
|
+from typing import NamedTuple
|
|
|
+
|
|
|
+class TestCase(NamedTuple):
|
|
|
+ name: str
|
|
|
+ engine_lists: list # Will contain short names, will be mapped to class names
|
|
|
+ expected_models_contains: list
|
|
|
+ min_count: int | None
|
|
|
+ exact_count: int | None
|
|
|
+ max_count: int | None
|
|
|
+
|
|
|
+# Helper function to map short names to class names
|
|
|
+def expand_engine_lists(engine_lists):
|
|
|
+ def map_engine(engine):
|
|
|
+ return inference_engine_classes.get(engine, engine) # Return original name if not found
|
|
|
+
|
|
|
+ return [[map_engine(engine) for engine in sublist]
|
|
|
+ for sublist in engine_lists]
|
|
|
+
|
|
|
+test_cases = [
|
|
|
+ TestCase(
|
|
|
+ name="single_mlx_engine",
|
|
|
+ engine_lists=[["mlx"]],
|
|
|
+ expected_models_contains=["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"],
|
|
|
+ min_count=10,
|
|
|
+ exact_count=None,
|
|
|
+ max_count=None
|
|
|
+ ),
|
|
|
+ TestCase(
|
|
|
+ name="single_tinygrad_engine",
|
|
|
+ engine_lists=[["tinygrad"]],
|
|
|
+ expected_models_contains=["llama-3.2-1b", "llama-3.2-3b"],
|
|
|
+ min_count=5,
|
|
|
+ exact_count=None,
|
|
|
+ max_count=10
|
|
|
+ ),
|
|
|
+ TestCase(
|
|
|
+ name="multiple_engines_or",
|
|
|
+ engine_lists=[["mlx", "tinygrad"], ["mlx"]],
|
|
|
+ expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
|
|
|
+ min_count=10,
|
|
|
+ exact_count=None,
|
|
|
+ max_count=None
|
|
|
+ ),
|
|
|
+ TestCase(
|
|
|
+ name="multiple_engines_all",
|
|
|
+ engine_lists=[["mlx", "tinygrad"], ["mlx", "tinygrad"]],
|
|
|
+ expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
|
|
|
+ min_count=10,
|
|
|
+ exact_count=None,
|
|
|
+ max_count=None
|
|
|
+ ),
|
|
|
+ TestCase(
|
|
|
+ name="distinct_engine_lists",
|
|
|
+ engine_lists=[["mlx"], ["tinygrad"]],
|
|
|
+ expected_models_contains=["llama-3.2-1b"],
|
|
|
+ min_count=5,
|
|
|
+ exact_count=None,
|
|
|
+ max_count=10
|
|
|
+ ),
|
|
|
+ TestCase(
|
|
|
+ name="no_engines",
|
|
|
+ engine_lists=[],
|
|
|
+ expected_models_contains=None,
|
|
|
+ min_count=None,
|
|
|
+ exact_count=len(model_cards),
|
|
|
+ max_count=None
|
|
|
+ ),
|
|
|
+ TestCase(
|
|
|
+ name="nonexistent_engine",
|
|
|
+ engine_lists=[["NonexistentEngine"]],
|
|
|
+ expected_models_contains=[],
|
|
|
+ min_count=None,
|
|
|
+ exact_count=0,
|
|
|
+ max_count=None
|
|
|
+ ),
|
|
|
+ TestCase(
|
|
|
+ name="dummy_engine",
|
|
|
+ engine_lists=[["dummy"]],
|
|
|
+ expected_models_contains=["dummy"],
|
|
|
+ min_count=None,
|
|
|
+ exact_count=1,
|
|
|
+ max_count=None
|
|
|
+ ),
|
|
|
+]
|
|
|
+
|
|
|
+class TestModelHelpers(unittest.TestCase):
|
|
|
+ def test_get_supported_models(self):
|
|
|
+ for case in test_cases:
|
|
|
+ with self.subTest(f"{case.name}_short_names"):
|
|
|
+ result = get_supported_models(case.engine_lists)
|
|
|
+ self._verify_results(case, result)
|
|
|
+
|
|
|
+ with self.subTest(f"{case.name}_class_names"):
|
|
|
+ class_name_lists = expand_engine_lists(case.engine_lists)
|
|
|
+ result = get_supported_models(class_name_lists)
|
|
|
+ self._verify_results(case, result)
|
|
|
+
|
|
|
+ def _verify_results(self, case, result):
|
|
|
+ if case.expected_models_contains:
|
|
|
+ for model in case.expected_models_contains:
|
|
|
+ self.assertIn(model, result)
|
|
|
+
|
|
|
+ if case.min_count:
|
|
|
+ self.assertGreater(len(result), case.min_count)
|
|
|
+
|
|
|
+ if case.exact_count is not None:
|
|
|
+ self.assertEqual(len(result), case.exact_count)
|
|
|
+
|
|
|
+ # Special case for distinct lists test
|
|
|
+ if case.name == "distinct_engine_lists":
|
|
|
+ self.assertLess(len(result), 10)
|
|
|
+ self.assertNotIn("mistral-nemo", result)
|
|
|
+
|
|
|
+ if case.max_count:
|
|
|
+ self.assertLess(len(result), case.max_count)
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ unittest.main()
|