Prechádzať zdrojové kódy

fix modelpool, add tests in test/test_model_helpers.py

Alex Cheema 5 mesiacov pred
rodič
commit
1b7e67832c
4 zmenil súbory, kde vykonal 148 pridanie a 18 odobranie
  1. 1 0
      .circleci/config.yml
  2. 2 15
      exo/api/chatgpt_api.py
  3. 24 3
      exo/models.py
  4. 121 0
      test/test_model_helpers.py

+ 1 - 0
.circleci/config.yml

@@ -126,6 +126,7 @@ jobs:
             METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
             echo "Running tokenizer tests..."
             python3 ./test/test_tokenizers.py
+            python3 ./test/test_model_helpers.py
 
   discovery_integration_test:
     macos:

+ 2 - 15
exo/api/chatgpt_api.py

@@ -13,11 +13,9 @@ import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict, shutdown
-from exo.inference.inference_engine import inference_engine_classes
-from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo, pretty_name
+from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
 from typing import Callable, Optional
 
 class Message:
@@ -218,18 +216,7 @@ class ChatGPTAPI:
     return web.json_response({
       "model pool": {
         model_name: pretty_name.get(model_name, model_name) 
-        for model_name in [
-          model_id for model_id, model_info in model_cards.items() 
-          if all(map(
-            lambda engine: engine in model_info["repo"],
-            list(dict.fromkeys([
-              inference_engine_classes.get(engine_name, None) 
-              for engine_list in self.node.topology_inference_engines_pool 
-              for engine_name in engine_list 
-              if engine_name is not None
-            ] + [self.inference_engine_classname]))
-          ))
-        ]
+        for model_name in get_supported_models(self.node.topology_inference_engines_pool)
       }
     })
   

+ 24 - 3
exo/models.py

@@ -1,11 +1,11 @@
 from exo.inference.shard import Shard
-from typing import Optional
+from typing import Optional, List
 
 model_cards = {
   ### llama
   "llama-3.2-1b": {
     "layers": 16,
-    "repo": { 
+    "repo": {
       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
     },
@@ -124,4 +124,25 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
   if repo is None or n_layers < 1:
     return None
   return Shard(model_id, 0, 0, n_layers)
-  
+
+def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
+  if not supported_inference_engine_lists:
+    return list(model_cards.keys())
+
+  from exo.inference.inference_engine import inference_engine_classes
+  supported_inference_engine_lists = [
+    [inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
+    for engine_list in supported_inference_engine_lists
+  ]
+
+  def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
+    return any(engine in model_info.get("repo", {}) for engine in engine_list)
+
+  def supports_all_engine_lists(model_info: dict) -> bool:
+    return all(has_any_engine(model_info, engine_list)
+              for engine_list in supported_inference_engine_lists)
+
+  return [
+    model_id for model_id, model_info in model_cards.items()
+    if supports_all_engine_lists(model_info)
+  ]

+ 121 - 0
test/test_model_helpers.py

@@ -0,0 +1,121 @@
+import unittest
+from exo.models import get_supported_models, model_cards
+from exo.inference.inference_engine import inference_engine_classes
+from typing import NamedTuple
+
+class TestCase(NamedTuple):
+  name: str
+  engine_lists: list  # Will contain short names, will be mapped to class names
+  expected_models_contains: list
+  min_count: int | None
+  exact_count: int | None
+  max_count: int | None
+
+# Helper function to map short names to class names
+def expand_engine_lists(engine_lists):
+  def map_engine(engine):
+    return inference_engine_classes.get(engine, engine)  # Return original name if not found
+
+  return [[map_engine(engine) for engine in sublist]
+          for sublist in engine_lists]
+
+test_cases = [
+  TestCase(
+    name="single_mlx_engine",
+    engine_lists=[["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="single_tinygrad_engine",
+    engine_lists=[["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="multiple_engines_or",
+    engine_lists=[["mlx", "tinygrad"], ["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="multiple_engines_all",
+    engine_lists=[["mlx", "tinygrad"], ["mlx", "tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="distinct_engine_lists",
+    engine_lists=[["mlx"], ["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="no_engines",
+    engine_lists=[],
+    expected_models_contains=None,
+    min_count=None,
+    exact_count=len(model_cards),
+    max_count=None
+  ),
+  TestCase(
+    name="nonexistent_engine",
+    engine_lists=[["NonexistentEngine"]],
+    expected_models_contains=[],
+    min_count=None,
+    exact_count=0,
+    max_count=None
+  ),
+  TestCase(
+    name="dummy_engine",
+    engine_lists=[["dummy"]],
+    expected_models_contains=["dummy"],
+    min_count=None,
+    exact_count=1,
+    max_count=None
+  ),
+]
+
+class TestModelHelpers(unittest.TestCase):
+  def test_get_supported_models(self):
+    for case in test_cases:
+      with self.subTest(f"{case.name}_short_names"):
+        result = get_supported_models(case.engine_lists)
+        self._verify_results(case, result)
+
+      with self.subTest(f"{case.name}_class_names"):
+        class_name_lists = expand_engine_lists(case.engine_lists)
+        result = get_supported_models(class_name_lists)
+        self._verify_results(case, result)
+
+  def _verify_results(self, case, result):
+    if case.expected_models_contains:
+      for model in case.expected_models_contains:
+        self.assertIn(model, result)
+
+    if case.min_count:
+      self.assertGreater(len(result), case.min_count)
+
+    if case.exact_count is not None:
+      self.assertEqual(len(result), case.exact_count)
+
+    # Special case for distinct lists test
+    if case.name == "distinct_engine_lists":
+      self.assertLess(len(result), 10)
+      self.assertNotIn("mistral-nemo", result)
+
+    if case.max_count:
+      self.assertLess(len(result), case.max_count)
+
+if __name__ == '__main__':
+  unittest.main()