
Merge pull request #433 from blindcrone/intercompatibility

Enabled inference engine intercompatibility
Alex Cheema, 7 months ago
commit 7070178de2

+ 30 - 9
exo/api/chatgpt_api.py

@@ -11,10 +11,11 @@ import traceback
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict
+from exo.inference.inference_engine import inference_engine_classes
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import model_base_shards
+from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable
 
 
@@ -171,6 +172,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
+    cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
 
     self.static_dir = Path(__file__).parent.parent/"tinychat"
     self.app.router.add_get("/", self.handle_root)
@@ -198,14 +200,33 @@ class ChatGPTAPI:
   async def handle_root(self, request):
     return web.FileResponse(self.static_dir/"index.html")
 
+  async def handle_model_support(self, request):
+    return web.json_response({
+      "model pool": {
+        model_name: pretty_name.get(model_name, model_name) 
+        for model_name in [
+          model_id for model_id, model_info in model_cards.items() 
+          if all(map(
+            lambda engine: engine in model_info["repo"],
+            list(dict.fromkeys([
+              inference_engine_classes.get(engine_name, None) 
+              for engine_list in self.node.topology_inference_engines_pool 
+              for engine_name in engine_list 
+              if engine_name is not None
+            ] + [self.inference_engine_classname]))
+          ))
+        ]
+      }
+    })
+  
   async def handle_get_models(self, request):
-    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_base_shards.items()])
+    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
-    shard = model_base_shards.get(data.get("model", self.default_model), {}).get(self.inference_engine_classname)
+    shard = build_base_shard(self.default_model, self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
-    tokenizer = await resolve_tokenizer(shard.model_id)
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
 
   async def handle_get_download_progress(self, request):
@@ -224,18 +245,18 @@ class ChatGPTAPI:
     chat_request = parse_chat_request(data, self.default_model)
     if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
       chat_request.model = self.default_model if self.default_model.startswith("llama") else "llama-3.2-1b"
-    if not chat_request.model or chat_request.model not in model_base_shards:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to {self.default_model}")
+    if not chat_request.model or chat_request.model not in model_cards:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
       chat_request.model = self.default_model
-    shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None)
+    shard = build_base_shard(chat_request.model, self.inference_engine_classname)
     if not shard:
-      supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines]
+      supported_models = [model for model, info in model_cards.items() if self.inference_engine_classname in info.get("repo", {})]
       return web.json_response(
         {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         status=400,
       )
 
-    tokenizer = await resolve_tokenizer(shard.model_id)
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
     prompt = build_prompt(tokenizer, chat_request.messages)
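
The new /modelpool endpoint returns only the models every node in the cluster can serve: it collects the inference engine class names active across the topology (plus the local node's own engine) and keeps a model only if its "repo" map covers all of them. Below is a minimal standalone sketch of that selection logic, assuming the model_cards, pretty_name, and inference_engine_classes definitions added later in this diff; topology_pool stands in for self.node.topology_inference_engines_pool.

from exo.inference.inference_engine import inference_engine_classes
from exo.models import model_cards, pretty_name

def model_pool(topology_pool, local_engine_classname):
  # Translate the short engine names reported by peers (e.g. "mlx") into class
  # names, de-duplicate while preserving order, and include the local engine.
  engine_classnames = list(dict.fromkeys(
    [inference_engine_classes.get(name) for group in topology_pool for name in group if name is not None]
    + [local_engine_classname]
  ))
  # Keep only models whose repo map has an entry for every active engine class.
  return {
    model_id: pretty_name.get(model_id, model_id)
    for model_id, info in model_cards.items()
    if all(classname in info["repo"] for classname in engine_classnames)
  }

# A cluster mixing MLX and tinygrad nodes only advertises models with repos for both.
print(model_pool([["mlx"], ["tinygrad"]], "MLXDynamicShardInferenceEngine"))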

+ 8 - 6
exo/download/hf/hf_shard_download.py

@@ -7,6 +7,7 @@ from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
 from exo.helpers import AsyncCallbackSystem, DEBUG
+from exo.models import model_cards, get_repo
 
 
 class HFShardDownloader(ShardDownloader):
@@ -17,11 +18,12 @@ class HFShardDownloader(ShardDownloader):
     self.completed_downloads: Dict[Shard, Path] = {}
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
 
-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    repo_name = get_repo(shard.model_id, inference_engine_name)
     if shard in self.completed_downloads:
       return self.completed_downloads[shard]
     if self.quick_check:
-      repo_root = get_repo_root(shard.model_id)
+      repo_root = get_repo_root(repo_name)
       snapshots_dir = repo_root/"snapshots"
       if snapshots_dir.exists():
         visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
@@ -51,7 +53,7 @@ class HFShardDownloader(ShardDownloader):
     self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
 
     # Start new download
-    download_task = asyncio.create_task(self._download_shard(shard))
+    download_task = asyncio.create_task(self._download_shard(shard, repo_name))
     self.active_downloads[shard] = download_task
     try:
       path = await download_task
@@ -63,14 +65,14 @@ class HFShardDownloader(ShardDownloader):
       if shard in self.active_downloads:
         self.active_downloads.pop(shard)
 
-  async def _download_shard(self, shard: Shard) -> Path:
+  async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
     async def wrapped_progress_callback(event: RepoProgressEvent):
       self._on_progress.trigger_all(shard, event)
 
-    weight_map = await get_weight_map(shard.model_id)
+    weight_map = await get_weight_map(repo_name)
     allow_patterns = get_allow_patterns(weight_map, shard)
 
-    return await download_repo_files(repo_id=shard.model_id, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
+    return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
 
   @property
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
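
ensure_shard now takes the caller's inference engine class name because a single short model_id can map to a different Hugging Face repo per engine. A brief usage sketch under that assumption, using the helpers added in exo/models.py below (the download call is left commented since it would fetch real weights):

from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.models import build_base_shard

downloader = HFShardDownloader(quick_check=True)
shard = build_base_shard("llama-3.2-1b", "TinygradDynamicShardInferenceEngine")

# The engine class name selects the repo: here unsloth/Llama-3.2-1B-Instruct
# rather than the MLX 4-bit build of the same model.
# path = await downloader.ensure_shard(shard, "TinygradDynamicShardInferenceEngine")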

+ 3 - 2
exo/download/shard_download.py

@@ -8,7 +8,7 @@ from exo.helpers import AsyncCallbackSystem
 
 class ShardDownloader(ABC):
   @abstractmethod
-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
     """
         Ensures that the shard is downloaded.
         Does not allow multiple overlapping downloads at once.
@@ -17,6 +17,7 @@ class ShardDownloader(ABC):
 
         Args:
             shard (Shard): The shard to download.
+            inference_engine_name (str): The inference engine used on the node hosting the shard.
         """
     pass
 
@@ -27,7 +28,7 @@ class ShardDownloader(ABC):
 
 
 class NoopShardDownloader(ShardDownloader):
-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
     return Path("/tmp/noop_shard")
 
   @property

+ 5 - 0
exo/inference/inference_engine.py

@@ -29,6 +29,11 @@ class InferenceEngine(ABC):
     output_data = await self.infer_tensor(request_id, shard, tokens)
     return output_data 
 
+inference_engine_classes = {
+  "mlx": "MLXDynamicShardInferenceEngine",
+  "tinygrad": "TinygradDynamicShardInferenceEngine",
+  "dummy": "DummyInferenceEngine",
+}
 
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
   if DEBUG >= 2:
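
The short names broadcast between nodes ("mlx", "tinygrad", "dummy") are translated through inference_engine_classes into the class names used as keys in model_cards[...]["repo"]. A small sketch of how the two fit together, using entries defined in exo/models.py below:

from exo.inference.inference_engine import inference_engine_classes
from exo.models import get_repo

classname = inference_engine_classes["tinygrad"]  # "TinygradDynamicShardInferenceEngine"
print(get_repo("llama-3-8b", classname))    # TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R
print(get_repo("qwen-2.5-7b", classname))   # None: no tinygrad repo registered for this model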

+ 1 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -62,7 +62,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return
 
-    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
 
     if self.shard != shard:
       loop = asyncio.get_running_loop()

+ 2 - 2
exo/inference/test_inference_engine.py

@@ -42,7 +42,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(next_resp_full, resp4)
 
 
-asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "mlx-community/Llama-3.2-1B-Instruct-4bit", 16))
+asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))
 
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
   import tinygrad
@@ -50,5 +50,5 @@ if os.getenv("RUN_TINYGRAD", default="0") == "1":
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
   asyncio.run(
-    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", 32)
+    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
   )

+ 1 - 1
exo/inference/tinygrad/inference.py

@@ -91,7 +91,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return
 
-    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
 
     if self.shard != shard:
       loop = asyncio.get_running_loop()

+ 4 - 3
exo/main.py

@@ -23,7 +23,7 @@ from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
-from exo.models import model_base_shards
+from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
 
 # parse args
@@ -175,11 +175,12 @@ async def shutdown(signal, loop):
 
 
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
-  shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
+  inference_class = inference_engine.__class__.__name__
+  shard = build_base_shard(model_name, inference_class)
   if not shard:
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     return
-  tokenizer = await resolve_tokenizer(shard.model_id)
+  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
   request_id = str(uuid.uuid4())
   callback_id = f"cli-wait-response-{request_id}"
   callback = node.on_token.register(callback_id)
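
Since Shard.model_id is now the short model name rather than a Hugging Face repo id, the CLI resolves the tokenizer through get_repo for the active engine. A minimal sketch of that flow, assuming the models.py helpers from this diff (the run call is commented out because it fetches the tokenizer over the network):

import asyncio
from exo.inference.tokenizers import resolve_tokenizer
from exo.models import build_base_shard, get_repo

async def resolve_for_cli(model_name: str, inference_class: str):
  shard = build_base_shard(model_name, inference_class)
  if not shard:
    raise ValueError(f"Unsupported model '{model_name}' for {inference_class}")
  # The tokenizer comes from the per-engine repo, not from shard.model_id itself.
  return await resolve_tokenizer(get_repo(shard.model_id, inference_class))

# tokenizer = asyncio.run(resolve_for_cli("llama-3.2-1b", "MLXDynamicShardInferenceEngine"))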

+ 101 - 37
exo/models.py

@@ -1,62 +1,126 @@
 from exo.inference.shard import Shard
+from typing import Optional
 
-model_base_shards = {
+model_cards = {
   ### llama
   "llama-3.2-1b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16),
+    "layers": 16,
+    "repo": { 
+      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
+      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
+    },
   },
   "llama-3.2-3b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-3B-Instruct", start_layer=0, end_layer=0, n_layers=28),
+    "layers": 28,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+    },
   },
   "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+    "layers": 32,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
+    },
   },
   "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+    },
   },
   "llama-3.1-70b-bf16": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED",
+       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+    },
   },
-  "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
-  "llama-3.1-405b-8bit": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", start_layer=0, end_layer=0, n_layers=126),},
   "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+    "layers": 32,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
+    },
   },
   "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
+    },
   },
+  "llama-3.1-405b": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit", }, },
+  "llama-3.1-405b-8bit": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", }, },
   ### mistral
-  "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
-  "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),},
+  "mistral-nemo": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit", }, },
+  "mistral-large": { "layers": 88, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit", }, },
   ### deepseek
-  "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),},
-  "deepseek-coder-v2.5": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", start_layer=0, end_layer=0, n_layers=60),},
+  "deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
+  "deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
   ### llava
-  "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),},
+  "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
   ### qwen
-  "qwen-2.5-coder-1.5b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
-  "qwen-2.5-coder-3b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=36),},
-  "qwen-2.5-coder-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
-  "qwen-2.5-coder-14b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48),},
-  "qwen-2.5-coder-32b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=64),},
-  "qwen-2.5-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
-  "qwen-2.5-math-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
-  "qwen-2.5-14b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48),},
-  "qwen-2.5-72b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),},
-  "qwen-2.5-math-72b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),},
+  "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
+  "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
+  "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
+  "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
+  "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
+  "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
+  "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
+  "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
+  "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
+  "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
   ### nemotron
-  "nemotron-70b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", start_layer=0, end_layer=0, n_layers=80),},
-  "nemotron-70b-bf16": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", start_layer=0, end_layer=0, n_layers=80),},
+  "nemotron-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", }, },
+  "nemotron-70b-bf16": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", }, },
   # gemma
-  "gemma2-9b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/gemma-2-9b-it-4bit", start_layer=0, end_layer=0, n_layers=42),},
-  "gemma2-27b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/gemma-2-27b-it-4bit", start_layer=0, end_layer=0, n_layers=46),},
+  "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
+  "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
   # dummy
-  "dummy": {"DummyInferenceEngine": Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8),},
+  "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
 }
+
+pretty_name = {
+  "llama-3.2-1b": "Llama 3.2 1B",
+  "llama-3.2-3b": "Llama 3.2 3B",
+  "llama-3.1-8b": "Llama 3.1 8B",
+  "llama-3.1-70b": "Llama 3.1 70B",
+  "llama-3.1-70b-bf16": "Llama 3.1 70B (BF16)",
+  "llama-3.1-405b": "Llama 3.1 405B",
+  "llama-3.1-405b-8bit": "Llama 3.1 405B (8-bit)",
+  "gemma2-9b": "Gemma2 9B",
+  "gemma2-27b": "Gemma2 27B",
+  "nemotron-70b": "Nemotron 70B",
+  "nemotron-70b-bf16": "Nemotron 70B (BF16)",
+  "mistral-nemo": "Mistral Nemo",
+  "mistral-large": "Mistral Large",
+  "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
+  "deepseek-coder-v2.5": "Deepseek Coder V2.5",
+  "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
+  "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
+  "qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
+  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
+  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
+  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
+  "qwen-2.5-7b": "Qwen 2.5 7B",
+  "qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
+  "qwen-2.5-14b": "Qwen 2.5 14B",
+  "qwen-2.5-72b": "Qwen 2.5 72B",
+  "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
+  "llama-3-8b": "Llama 3 8B",
+  "llama-3-70b": "Llama 3 70B",
+}
+
+def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
+  return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
+
+def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
+  repo = get_repo(model_id, inference_engine_classname)
+  n_layers = model_cards.get(model_id, {}).get("layers", 0)
+  if repo is None or n_layers < 1:
+    return None
+  return Shard(model_id, 0, 0, n_layers)
+  
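
For reference, a short sketch of the two new helpers in use (output comments are approximate, assuming Shard is a plain dataclass):

from exo.models import build_base_shard, get_repo

shard = build_base_shard("llama-3.1-8b", "TinygradDynamicShardInferenceEngine")
print(shard)  # roughly Shard(model_id='llama-3.1-8b', start_layer=0, end_layer=0, n_layers=32)
print(get_repo("llama-3.1-8b", "TinygradDynamicShardInferenceEngine"))
# mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated

# Models with no repo for the requested engine yield None instead of a Shard.
print(build_base_shard("mistral-nemo", "TinygradDynamicShardInferenceEngine"))  # None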

+ 1 - 6
exo/orchestration/standard_node.py

@@ -363,12 +363,7 @@ class StandardNode(Node):
     supported_engines = self.get_supported_inference_engines()
     await self.broadcast_supported_engines(supported_engines)
     if len(self.get_topology_inference_engines()):
-      if any(len(engines) == 1 and "tinygrad" in engines for engines in self.get_topology_inference_engines()):
-        if DEBUG >= 1: print("Found node with only tinygrad, using tinygrad on all nodes")
-        self.inference_engine = get_inference_engine("tinygrad", self.shard_downloader)
-      else:
-        if DEBUG >= 1: print("All nodes can use mlx, using mlx for inference")
-        self.inference_engine = get_inference_engine("mlx", self.shard_downloader)
+      self.inference_engine = get_inference_engine(supported_engines[0], self.shard_downloader)
 
   async def periodic_topology_collection(self, interval: int):
     while True:
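
The hard-coded tinygrad-vs-mlx negotiation is replaced by adopting the head of the node's own list of supported engines. A toy sketch of the new selection rule; the example lists are assumptions about what get_supported_inference_engines might report on different hardware:

def pick_engine(supported_engines):
  # The list is assumed to be priority-ordered, so the first entry wins.
  return supported_engines[0] if supported_engines else None

assert pick_engine(["mlx", "tinygrad"]) == "mlx"   # e.g. an Apple Silicon node
assert pick_engine(["tinygrad"]) == "tinygrad"     # e.g. a Linux node without MLX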

+ 2 - 30
exo/tinychat/index.html

@@ -29,36 +29,8 @@
     <div x-show="errorMessage" x-transition.opacity x-text="errorMessage" class="toast">
     </div>
 <div class="model-selector">
-<select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel">
-<option value="llama-3.2-1b">Llama 3.2 1B</option>
-<option value="llama-3.2-3b">Llama 3.2 3B</option>
-<option value="llama-3.1-8b">Llama 3.1 8B</option>
-<option value="llama-3.1-70b">Llama 3.1 70B</option>
-<option value="llama-3.1-70b-bf16">Llama 3.1 70B (BF16)</option>
-<option value="llama-3.1-405b">Llama 3.1 405B</option>
-<option value="llama-3.1-405b-8bit">Llama 3.1 405B (8-bit)</option>
-<option value="gemma2-9b">Gemma2 9B</option>
-<option value="gemma2-27b">Gemma2 27B</option>
-<option value="nemotron-70b">Nemotron 70B</option>
-<option value="nemotron-70b-bf16">Nemotron 70B (BF16)</option>
-<option value="mistral-nemo">Mistral Nemo</option>
-<option value="mistral-large">Mistral Large</option>
-<option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
-<option value="deepseek-coder-v2.5">Deepseek Coder V2.5</option>
-<option value="llava-1.5-7b-hf">LLaVa 1.5 7B (Vision Model)</option>
-<option value="qwen-2.5-coder-1.5b">Qwen 2.5 Coder 1.5B</option>
-<option value="qwen-2.5-coder-3b">Qwen 2.5 Coder 3B</option>
-<option value="qwen-2.5-coder-7b">Qwen 2.5 Coder 7B</option>
-<option value="qwen-2.5-coder-14b">Qwen 2.5 Coder 14B</option>
-<option value="qwen-2.5-coder-32b">Qwen 2.5 Coder 32B</option>
-<option value="qwen-2.5-7b">Qwen 2.5 7B</option>
-<option value="qwen-2.5-math-7b">Qwen 2.5 7B (Math)</option>
-<option value="qwen-2.5-14b">Qwen 2.5 14B</option>
-<option value="qwen-2.5-72b">Qwen 2.5 72B</option>
-<option value="qwen-2.5-math-72b">Qwen 2.5 72B (Math)</option>
-<option value="llama-3-8b">Llama 3 8B</option>
-<option value="llama-3-70b">Llama 3 70B</option>
-</select>
+  <select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
+  </select>
 </div>
 <div @popstate.window="
       if (home === 2) {

+ 51 - 0
exo/tinychat/index.js

@@ -72,6 +72,56 @@ document.addEventListener("alpine:init", () => {
       return `${s}s`;
     },
 
+    async populateSelector() {
+      try {
+        const response = await fetch(`${window.location.origin}/modelpool`);
+        const responseText = await response.text(); // Get raw response text first
+        
+        if (!response.ok) {
+          throw new Error(`HTTP error! status: ${response.status}`);
+        }
+        
+        // Try to parse the response text
+        let responseJson;
+        try {
+          responseJson = JSON.parse(responseText);
+        } catch (parseError) {
+          console.error('Failed to parse JSON:', parseError);
+          throw new Error(`Invalid JSON response: ${responseText}`);
+        }
+
+        const sel = document.querySelector(".model-select");
+        if (!sel) {
+          throw new Error("Could not find model selector element");
+        }
+
+        // Clear the current options and add new ones
+        sel.innerHTML = '';
+          
+        const modelDict = responseJson["model pool"];
+        if (!modelDict) {
+          throw new Error("Response missing 'model pool' property");
+        }
+
+        Object.entries(modelDict).forEach(([key, value]) => {
+          const opt = document.createElement("option");
+          opt.value = key;
+          opt.textContent = value;
+          sel.appendChild(opt);
+        });
+
+        // Set initial value to the first model
+        const firstKey = Object.keys(modelDict)[0];
+        if (firstKey) {
+          sel.value = firstKey;
+          this.cstate.selectedModel = firstKey;
+        }
+      } catch (error) {
+        console.error("Error populating model selector:", error);
+        this.errorMessage = `Failed to load models: ${error.message}`;
+      }
+    },
+
     async handleImageUpload(event) {
       const file = event.target.files[0];
       if (file) {
@@ -535,6 +585,7 @@ function createParser(onParse) {
     }
   }
 }
+
 const BOM = [239, 187, 191];
 function hasBom(buffer) {
   return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode);