@@ -4,7 +4,7 @@ import asyncio
 import uuid
 import time
 import traceback
-from typing import List, Dict, Optional, Tuple, Union
+from typing import List, Dict, Optional, Tuple, Union, Set
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
 from .node import Node
@@ -15,7 +15,8 @@ from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
 from exo.download.hf.hf_helpers import RepoProgressEvent
-
+from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
 
 class StandardNode(Node):
   def __init__(
@@ -27,6 +28,7 @@ class StandardNode(Node):
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
     topology_viz: Optional[TopologyViz] = None,
+    shard_downloader: Optional[HFShardDownloader] = None,
   ):
     self.id = _id
     self.inference_engine = inference_engine
@@ -43,6 +45,8 @@ class StandardNode(Node):
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
+    self.topology_inference_engines_pool: List[List[str]] = []
+    self.shard_downloader = shard_downloader
 
   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
@@ -59,6 +63,10 @@ class StandardNode(Node):
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
+      if status_data.get("type", "") == "supported_inference_engines":
+        node_id = status_data.get("node_id")
+        engines = status_data.get("engines", [])
+        self.topology_inference_engines_pool.append(engines)
       if status_data.get("type", "") == "node_status":
         if status_data.get("status", "").startswith("start_"):
           self.current_topology.active_node_id = status_data.get("node_id")
@@ -76,6 +84,26 @@ class StandardNode(Node):
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
       if DEBUG >= 1: traceback.print_exc()
 
+  def get_supported_inference_engines(self):
+    supported_engine_names = []
+    if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
+      supported_engine_names.append('mlx')
+      supported_engine_names.append('tinygrad')
+    else:
+      supported_engine_names.append('tinygrad')
+    return supported_engine_names
+
+  async def broadcast_supported_engines(self, supported_engines_names: List[str]):
+    status_message = json.dumps({
+      "type": "supported_inference_engines",
+      "node_id": self.id,
+      "engines": supported_engines_names
+    })
+    await self.broadcast_opaque_status("", status_message)
+
+  def get_topology_inference_engines(self) -> List[List[str]]:
+    return self.topology_inference_engines_pool
+
   async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
@@ -338,6 +366,17 @@ class StandardNode(Node):
     self.peers = next_peers
     return len(peers_added) > 0 or len(peers_removed) > 0 or len(peers_updated) > 0
 
+  async def select_best_inference_engine(self):
+    supported_engines = self.get_supported_inference_engines()
+    await self.broadcast_supported_engines(supported_engines)
+    if len(self.get_topology_inference_engines()):
+      if any(len(engines) == 1 and "tinygrad" in engines for engines in self.get_topology_inference_engines()):
+        if DEBUG >= 1: print("Found node with only tinygrad, using tinygrad on all nodes")
+        self.inference_engine = get_inference_engine("tinygrad", self.shard_downloader)
+      else:
+        if DEBUG >= 1: print("All nodes can use mlx, using mlx for inference")
+        self.inference_engine = get_inference_engine("mlx", self.shard_downloader)
+
   async def periodic_topology_collection(self, interval: int):
     while True:
       await asyncio.sleep(interval)
@@ -346,6 +385,7 @@ class StandardNode(Node):
         if DEBUG >= 2: print(f"{did_peers_change=}")
         if did_peers_change:
           await self.collect_topology()
+          await self.select_best_inference_engine()
       except Exception as e:
         print(f"Error collecting topology: {e}")
         traceback.print_exc()
@@ -424,7 +464,6 @@ class StandardNode(Node):
       except Exception as e:
         print(f"Error sending opaque status to {peer.id()}: {e}")
         traceback.print_exc()
-
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
     # in the case of opaque status, we also want to receive our own opaque statuses
     self.on_opaque_status.trigger_all(request_id, status)
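
The core of this change is the selection rule applied after each topology refresh: every node broadcasts the engines it supports as an opaque "supported_inference_engines" status, and a node falls back to tinygrad whenever any peer reports tinygrad as its only engine; otherwise it uses mlx. A minimal sketch of that rule as a standalone helper, so it can be exercised in isolation (names here are illustrative only, not part of this patch or the exo API):

  from typing import List

  def pick_engine(engine_pools: List[List[str]]) -> str:
    # Fall back to tinygrad if any node in the topology reported tinygrad
    # as its only supported engine; otherwise every node can run mlx.
    if any(len(engines) == 1 and "tinygrad" in engines for engines in engine_pools):
      return "tinygrad"
    return "mlx"

  assert pick_engine([["mlx", "tinygrad"], ["tinygrad"]]) == "tinygrad"
  assert pick_engine([["mlx", "tinygrad"], ["mlx", "tinygrad"]]) == "mlx"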