
Enabled inference engine intercompatibility

If two inference engines can run the same model, nodes running those engines can now interoperate.
Nel Nibcord 7 months ago
commit d69a9c4d43
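
The heart of the change is in exo/models.py below: each model now maps to per-engine Hugging Face repos rather than per-engine Shard objects, so two nodes running different engines can serve the same model whenever both engines have a repo entry. A minimal sketch of that check, assuming exo at this commit is importable:

from exo.models import model_cards, get_repo

MLX = "MLXDynamicShardInferenceEngine"
TINYGRAD = "TinygradDynamicShardInferenceEngine"

# Models both engines can serve; only these can span a mixed MLX/tinygrad cluster.
shared = [name for name, card in model_cards.items()
          if MLX in card.get("repo", {}) and TINYGRAD in card.get("repo", {})]
print(shared)  # ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b', ...]

# Each engine resolves its own weights repo for the same logical model.
print(get_repo("llama-3.2-1b", MLX))       # mlx-community/Llama-3.2-1B-Instruct-4bit
print(get_repo("llama-3.2-1b", TINYGRAD))  # unsloth/Llama-3.2-1B-Instruct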

+ 9 - 9
exo/api/chatgpt_api.py

@@ -14,7 +14,7 @@ from exo.helpers import PrefixDict
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import model_base_shards
+from exo.models import build_base_shard, model_cards, get_repo
 from typing import Callable
 
 
@@ -199,13 +199,13 @@ class ChatGPTAPI:
     return web.FileResponse(self.static_dir/"index.html")
 
   async def handle_get_models(self, request):
-    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_base_shards.items()])
+    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
-    shard = model_base_shards.get(data.get("model", self.default_model), {}).get(self.inference_engine_classname)
+    shard = build_base_shard(data.get("model", self.default_model), self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
-    tokenizer = await resolve_tokenizer(shard.model_id)
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
 
   async def handle_get_download_progress(self, request):
@@ -224,18 +224,18 @@ class ChatGPTAPI:
     chat_request = parse_chat_request(data, self.default_model)
     if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
       chat_request.model = self.default_model if self.default_model.startswith("llama") else "llama-3.2-1b"
-    if not chat_request.model or chat_request.model not in model_base_shards:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to {self.default_model}")
+    if not chat_request.model or chat_request.model not in model_cards:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
       chat_request.model = self.default_model
-    shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None)
+    shard = build_base_shard(chat_request.model, self.inference_engine_classname)
     if not shard:
-      supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines]
+      supported_models = [model for model, info in model_cards.items() if self.inference_engine_classname in info.get("repo", {})]
       return web.json_response(
         {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         status=400,
       )
 
-    tokenizer = await resolve_tokenizer(shard.model_id)
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
     prompt = build_prompt(tokenizer, chat_request.messages)
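
The handler now resolves in two steps: the model name and engine class give a base shard, and the shard's model_id plus the engine give the engine-specific repo that the tokenizer is loaded from. A standalone sketch of the same path, with the engine name and model chosen for illustration:

import asyncio
from exo.models import build_base_shard, get_repo
from exo.inference.tokenizers import resolve_tokenizer

engine = "MLXDynamicShardInferenceEngine"  # the API uses self.inference_engine_classname here
shard = build_base_shard("llama-3.2-1b", engine)
if shard is None:
  raise SystemExit(f"llama-3.2-1b is not supported by {engine}")
repo = get_repo(shard.model_id, engine)  # "mlx-community/Llama-3.2-1B-Instruct-4bit"
tokenizer = asyncio.run(resolve_tokenizer(repo))
print(type(tokenizer).__name__)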

+ 8 - 6
exo/download/hf/hf_shard_download.py

@@ -7,6 +7,7 @@ from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
 from exo.helpers import AsyncCallbackSystem, DEBUG
+from exo.models import model_cards, get_repo
 
 
 class HFShardDownloader(ShardDownloader):
@@ -17,11 +18,12 @@ class HFShardDownloader(ShardDownloader):
     self.completed_downloads: Dict[Shard, Path] = {}
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
 
-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    repo_name = get_repo(shard.model_id, inference_engine_name)
     if shard in self.completed_downloads:
       return self.completed_downloads[shard]
     if self.quick_check:
-      repo_root = get_repo_root(shard.model_id)
+      repo_root = get_repo_root(repo_name)
       snapshots_dir = repo_root/"snapshots"
       if snapshots_dir.exists():
         visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
@@ -51,7 +53,7 @@ class HFShardDownloader(ShardDownloader):
     self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
 
     # Start new download
-    download_task = asyncio.create_task(self._download_shard(shard))
+    download_task = asyncio.create_task(self._download_shard(shard, repo_name))
     self.active_downloads[shard] = download_task
     try:
       path = await download_task
@@ -63,14 +65,14 @@ class HFShardDownloader(ShardDownloader):
       if shard in self.active_downloads:
         self.active_downloads.pop(shard)
 
-  async def _download_shard(self, shard: Shard) -> Path:
+  async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
     async def wrapped_progress_callback(event: RepoProgressEvent):
       self._on_progress.trigger_all(shard, event)
 
-    weight_map = await get_weight_map(shard.model_id)
+    weight_map = await get_weight_map(repo_name)
     allow_patterns = get_allow_patterns(weight_map, shard)
 
-    return await download_repo_files(repo_id=shard.model_id, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
+    return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
 
   @property
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
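
Callers now tell the downloader which engine the shard is for, and the downloader maps shard.model_id to that engine's repo via get_repo before touching the Hugging Face cache. A hedged sketch of a direct call, assuming the default HFShardDownloader constructor; the shard values mirror the llama-3.2-1b card:

import asyncio
from exo.inference.shard import Shard
from exo.download.hf.hf_shard_download import HFShardDownloader

shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16)
downloader = HFShardDownloader()  # assumed default arguments
# Downloads the weights on first run, then returns the local snapshot directory
# for mlx-community/Llama-3.2-1B-Instruct-4bit.
path = asyncio.run(downloader.ensure_shard(shard, "MLXDynamicShardInferenceEngine"))
print(path)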

+ 3 - 2
exo/download/shard_download.py

@@ -8,7 +8,7 @@ from exo.helpers import AsyncCallbackSystem
 
 class ShardDownloader(ABC):
   @abstractmethod
-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
     """
         Ensures that the shard is downloaded.
         Does not allow multiple overlapping downloads at once.
@@ -17,6 +17,7 @@ class ShardDownloader(ABC):
 
         Args:
             shard (Shard): The shard to download.
+            inference_engine_name (str): The class name of the inference engine used on the node hosting the shard.
         """
     pass
 
@@ -27,7 +28,7 @@ class ShardDownloader(ABC):
 
 
 class NoopShardDownloader(ShardDownloader):
-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
     return Path("/tmp/noop_shard")
 
   @property
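
Every ShardDownloader implementation has to pick up the widened signature. A minimal sketch of a custom downloader that serves shards from a pre-populated local directory (the directory layout is hypothetical), mostly to show where the engine name enters:

from pathlib import Path
from typing import Tuple
from exo.inference.shard import Shard
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import AsyncCallbackSystem
from exo.models import get_repo

class LocalDirShardDownloader(ShardDownloader):
  """Resolves shards to local directories named after the engine-specific repo."""
  def __init__(self, root: Path = Path("/var/models")):  # hypothetical location
    self.root = root
    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
    repo = get_repo(shard.model_id, inference_engine_name)
    if repo is None:
      raise ValueError(f"{shard.model_id} has no repo for {inference_engine_name}")
    return self.root/repo.replace("/", "--")

  @property
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
    return self._on_progress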

+ 1 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -62,7 +62,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return
 
-    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
 
     if self.shard != shard:
       loop = asyncio.get_running_loop()

+ 1 - 1
exo/inference/tinygrad/inference.py

@@ -91,7 +91,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     if self.shard == shard:
       return
 
-    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
 
     if self.shard != shard:
       loop = asyncio.get_running_loop()
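
Both engines pass self.__class__.__name__, so the engine class name is the contract: it has to match the key under each model card's "repo" mapping. A quick check of that coupling, assuming tinygrad is installed so the module imports cleanly:

from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
from exo.models import model_cards

# The string handed to ensure_shard is exactly the repo key in models.py.
assert TinygradDynamicShardInferenceEngine.__name__ in model_cards["llama-3.2-1b"]["repo"]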

+ 4 - 3
exo/main.py

@@ -23,7 +23,7 @@ from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
-from exo.models import model_base_shards
+from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
 
 # parse args
@@ -175,11 +175,12 @@ async def shutdown(signal, loop):
 
 
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
-  shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
+  inference_class = inference_engine.__class__.__name__
+  shard = build_base_shard(model_name, inference_class)
   if not shard:
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     return
-  tokenizer = await resolve_tokenizer(shard.model_id)
+  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
   request_id = str(uuid.uuid4())
   callback_id = f"cli-wait-response-{request_id}"
   callback = node.on_token.register(callback_id)

+ 70 - 37
exo/models.py

@@ -1,62 +1,95 @@
 from exo.inference.shard import Shard
+from typing import Optional
 
-model_base_shards = {
+model_cards = {
   ### llama
   "llama-3.2-1b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16),
+    "layers": 16,
+    "repo": { 
+      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
+      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
+    },
   },
   "llama-3.2-3b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-3B-Instruct", start_layer=0, end_layer=0, n_layers=28),
+    "layers": 28,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+    },
   },
   "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+    "layers": 32,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
+    },
   },
   "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+    },
   },
   "llama-3.1-70b-bf16": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED",
+       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+    },
   },
-  "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
-  "llama-3.1-405b-8bit": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", start_layer=0, end_layer=0, n_layers=126),},
   "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+    "layers": 32,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
+    },
   },
   "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
+    },
   },
+  "llama-3.1-405b": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit", }, },
+  "llama-3.1-405b-8bit": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", }, },
   ### mistral
-  "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
-  "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),},
+  "mistral-nemo": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit", }, },
+  "mistral-large": { "layers": 88, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit", }, },
   ### deepseek
-  "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),},
-  "deepseek-coder-v2.5": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", start_layer=0, end_layer=0, n_layers=60),},
+  "deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
+  "deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
   ### llava
-  "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),},
+  "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
   ### qwen
-  "qwen-2.5-coder-1.5b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
-  "qwen-2.5-coder-3b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=36),},
-  "qwen-2.5-coder-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
-  "qwen-2.5-coder-14b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48),},
-  "qwen-2.5-coder-32b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=64),},
-  "qwen-2.5-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
-  "qwen-2.5-math-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
-  "qwen-2.5-14b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48),},
-  "qwen-2.5-72b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),},
-  "qwen-2.5-math-72b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),},
+  "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
+  "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
+  "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
+  "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
+  "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
+  "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
+  "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
+  "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
+  "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
+  "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
   ### nemotron
-  "nemotron-70b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", start_layer=0, end_layer=0, n_layers=80),},
-  "nemotron-70b-bf16": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", start_layer=0, end_layer=0, n_layers=80),},
+  "nemotron-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", }, },
+  "nemotron-70b-bf16": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", }, },
   # gemma
-  "gemma2-9b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/gemma-2-9b-it-4bit", start_layer=0, end_layer=0, n_layers=42),},
-  "gemma2-27b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/gemma-2-27b-it-4bit", start_layer=0, end_layer=0, n_layers=46),},
+  "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
+  "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
   # dummy
-  "dummy": {"DummyInferenceEngine": Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8),},
+  "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
 }
+
+def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
+  return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
+
+def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
+  repo = get_repo(model_id, inference_engine_classname)
+  n_layers = model_cards.get(model_id, {}).get("layers", 0)
+  if repo is None or n_layers < 1:
+    return None
+  return Shard(model_id, 0, 0, n_layers)
+  
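
Both helpers fail soft: an unknown model, a missing "repo" entry for the engine, or a zero layer count yields None rather than a KeyError, which is what the call sites above rely on when they report unsupported model/engine combinations. A quick sketch:

from exo.models import build_base_shard, get_repo

print(get_repo("llama-3.2-1b", "TinygradDynamicShardInferenceEngine"))   # unsloth/Llama-3.2-1B-Instruct
print(get_repo("mistral-nemo", "TinygradDynamicShardInferenceEngine"))   # None: MLX-only card
print(build_base_shard("not-a-model", "MLXDynamicShardInferenceEngine")) # None: unknown model

shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
print(shard.model_id, shard.n_layers)  # llama-3.1-8b 32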

+ 1 - 6
exo/orchestration/standard_node.py

@@ -363,12 +363,7 @@ class StandardNode(Node):
     supported_engines = self.get_supported_inference_engines()
     await self.broadcast_supported_engines(supported_engines)
     if len(self.get_topology_inference_engines()):
-      if any(len(engines) == 1 and "tinygrad" in engines for engines in self.get_topology_inference_engines()):
-        if DEBUG >= 1: print("Found node with only tinygrad, using tinygrad on all nodes")
-        self.inference_engine = get_inference_engine("tinygrad", self.shard_downloader)
-      else:
-        if DEBUG >= 1: print("All nodes can use mlx, using mlx for inference")
-        self.inference_engine = get_inference_engine("mlx", self.shard_downloader)
+      self.inference_engine = get_inference_engine(supported_engines[0], self.shard_downloader)
 
   async def periodic_topology_collection(self, interval: int):
     while True: