
Merge pull request #377 from josh1593/autoselection-of-inference-engines

Autoselection of inference engines
Alex Cheema · 6 months ago
commit 9dc2932c39
2 files changed, 46 insertions(+), 5 deletions(-)
  1. exo/main.py (+4, -2)
  2. exo/orchestration/standard_node.py (+42, -3)

exo/main.py (+4, -2)

@@ -2,6 +2,7 @@ import argparse
 import asyncio
 import signal
 import json
+import logging
 import time
 import traceback
 import uuid
@@ -51,6 +52,7 @@ args = parser.parse_args()
 
 print_yellow_exo()
 
+
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
@@ -86,7 +88,8 @@ node = StandardNode(
   discovery,
   partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
   max_generate_tokens=args.max_generate_tokens,
-  topology_viz=topology_viz
+  topology_viz=topology_viz,
+  shard_downloader=shard_downloader
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
@@ -144,7 +147,6 @@ async def shutdown(signal, loop):
   await server.stop()
   loop.stop()
 
-
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
   shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
   if not shard:
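
The main.py change wires the shard downloader into StandardNode so the node can rebuild its inference engine later, when select_best_inference_engine picks a different backend. A minimal sketch of what the node does with it, assuming HFShardDownloader can be constructed with its default arguments (that construction is an assumption, not shown in this diff):

from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.inference.inference_engine import get_inference_engine

# Assumed default construction; the real instance is created earlier in exo/main.py.
shard_downloader = HFShardDownloader()
# Same call select_best_inference_engine uses when it swaps engines:
engine = get_inference_engine("tinygrad", shard_downloader)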

exo/orchestration/standard_node.py (+42, -3)

@@ -4,7 +4,7 @@ import asyncio
 import uuid
 import time
 import traceback
-from typing import List, Dict, Optional, Tuple, Union
+from typing import List, Dict, Optional, Tuple, Union, Set
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
 from .node import Node
@@ -15,7 +15,8 @@ from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
 from exo.download.hf.hf_helpers import RepoProgressEvent
-
+from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
 
 class StandardNode(Node):
   def __init__(
@@ -27,6 +28,7 @@ class StandardNode(Node):
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
     topology_viz: Optional[TopologyViz] = None,
+    shard_downloader: Optional[HFShardDownloader] = None,
   ):
     self.id = _id
     self.inference_engine = inference_engine
@@ -43,6 +45,8 @@ class StandardNode(Node):
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
+    self.topology_inference_engines_pool: List[List[str]] = []
+    self.shard_downloader = shard_downloader
 
   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
@@ -59,6 +63,10 @@ class StandardNode(Node):
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
+      if status_data.get("type", "") == "supported_inference_engines":
+        node_id = status_data.get("node_id")
+        engines = status_data.get("engines", [])
+        self.topology_inference_engines_pool.append(engines)
       if status_data.get("type", "") == "node_status":
         if status_data.get("status", "").startswith("start_"):
           self.current_topology.active_node_id = status_data.get("node_id")
@@ -76,6 +84,26 @@ class StandardNode(Node):
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
       if DEBUG >= 1: traceback.print_exc()
 
+  def get_supported_inference_engines(self):
+    supported_engine_names = []
+    if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
+        supported_engine_names.append('mlx')
+        supported_engine_names.append('tinygrad')
+    else:
+        supported_engine_names.append('tinygrad')
+    return supported_engine_names
+
+  async def broadcast_supported_engines(self, supported_engines_names: List[str]):
+    status_message = json.dumps({
+        "type": "supported_inference_engines",
+        "node_id": self.id,
+        "engines": supported_engines_names
+    })
+    await self.broadcast_opaque_status("", status_message)
+
+  def get_topology_inference_engines(self) -> List[List[str]]:
+    return self.topology_inference_engines_pool
+
   async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
@@ -338,6 +366,17 @@ class StandardNode(Node):
     self.peers = next_peers
     return len(peers_added) > 0 or len(peers_removed) > 0 or len(peers_updated) > 0
 
+  async def select_best_inference_engine(self):
+    supported_engines = self.get_supported_inference_engines()
+    await self.broadcast_supported_engines(supported_engines)
+    if len(self.get_topology_inference_engines()):
+      if any(len(engines) == 1 and "tinygrad" in engines for engines in self.get_topology_inference_engines()):
+        if DEBUG >= 1: print("Found node with only tinygrad, using tinygrad on all nodes")
+        self.inference_engine = get_inference_engine("tinygrad", self.shard_downloader)
+      else:
+        if DEBUG >= 1: print("All nodes can use mlx, using mlx for inference")
+        self.inference_engine = get_inference_engine("mlx", self.shard_downloader) 
+
   async def periodic_topology_collection(self, interval: int):
     while True:
       await asyncio.sleep(interval)
@@ -346,6 +385,7 @@ class StandardNode(Node):
         if DEBUG >= 2: print(f"{did_peers_change=}")
         if did_peers_change:
           await self.collect_topology()
+          await self.select_best_inference_engine()
       except Exception as e:
         print(f"Error collecting topology: {e}")
         traceback.print_exc()
@@ -424,7 +464,6 @@ class StandardNode(Node):
       except Exception as e:
         print(f"Error sending opaque status to {peer.id()}: {e}")
         traceback.print_exc()
-
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
     # in the case of opaque status, we also want to receive our own opaque statuses
     self.on_opaque_status.trigger_all(request_id, status)
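
Taken together, the new methods implement a simple consensus rule: each node broadcasts the engines it supports as an opaque "supported_inference_engines" status, every node accumulates those lists in topology_inference_engines_pool, and select_best_inference_engine drops the whole topology back to tinygrad if any peer supports only tinygrad, otherwise keeps mlx. A standalone sketch of just that rule (choose_engine is a hypothetical helper, not part of the commit):

from typing import List

def choose_engine(engine_pools: List[List[str]]) -> str:
  # Mirrors the check in select_best_inference_engine: a node that only
  # knows tinygrad forces tinygrad everywhere; otherwise everyone stays on mlx.
  if any(len(engines) == 1 and "tinygrad" in engines for engines in engine_pools):
    return "tinygrad"
  return "mlx"

print(choose_engine([["mlx", "tinygrad"], ["tinygrad"]]))         # -> tinygrad
print(choose_engine([["mlx", "tinygrad"], ["mlx", "tinygrad"]]))  # -> mlx

In the real node, the chosen name is then passed to get_inference_engine(name, self.shard_downloader) to replace self.inference_engine in place, and the whole check is re-run from periodic_topology_collection whenever the peer set changes.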