@@ -4,7 +4,7 @@ import asyncio
 import uuid
 import time
 import traceback
-from typing import List, Dict, Optional, Tuple, Union
+from typing import List, Dict, Optional, Tuple, Union, Set
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
 from .node import Node
@@ -15,6 +15,8 @@ from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
 from exo.download.hf.hf_helpers import RepoProgressEvent
+from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
 import logging
 
 logger = logging.getLogger(__name__)
@@ -29,6 +31,7 @@ class StandardNode(Node):
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
     topology_viz: Optional[TopologyViz] = None,
+    shard_downloader: Optional[HFShardDownloader] = None,
   ):
     self.id = _id
     self.inference_engine = inference_engine
@@ -45,7 +48,8 @@ class StandardNode(Node):
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
-    self.topology_inference_engines_pool: List[str] = []
+    self.topology_inference_engines_pool: List[List[str]] = []
+    self.shard_downloader = shard_downloader
 
   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
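
The hunk above changes the pool from a flat list of engine names to one list per
reporting node. A minimal sketch of what that shape allows (the sample values are
hypothetical):

    pool = [["mlx", "tinygrad"], ["tinygrad"]]  # one entry per reporting node
    # Engines every node supports: intersect the per-node lists.
    common = set(pool[0]).intersection(*(set(engines) for engines in pool[1:]))
    # Here common == {"tinygrad"}, the only engine usable everywhere.
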
@@ -62,6 +66,14 @@ class StandardNode(Node):
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
+      if status_data.get("type", "") == "supported_inference_engines":
+        logger.error(f'supported_inference_engines: {status_data}')
+        logger.error('inside on_status data')
+        node_id = status_data.get("node_id")
+        engines = status_data.get("engines", [])
+        logger.error(f'engines: {engines}')
+        self.topology_inference_engines_pool.append(engines)
+        logger.error(f'topology_inference_engines_pool: {self.topology_inference_engines_pool}')
       if status_data.get("type", "") == "node_status":
         if status_data.get("status", "").startswith("start_"):
           self.current_topology.active_node_id = status_data.get("node_id")
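
This handler consumes the status message produced by broadcast_supported_engines
later in this diff. The payload looks like the following (the node id is a
placeholder):

    import json

    opaque_status = json.dumps({
      "type": "supported_inference_engines",
      "node_id": "node-a",  # placeholder id
      "engines": ["mlx", "tinygrad"]
    })
    status_data = json.loads(opaque_status)
    assert status_data.get("type", "") == "supported_inference_engines"
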
@@ -82,22 +94,24 @@ class StandardNode(Node):
   def get_supported_inference_engines(self):
     supported_engine_names = []
     if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
-      supported_engine_names.extend(['mlx', 'tinygrad'])
+      supported_engine_names.append('mlx')
+      supported_engine_names.append('tinygrad')
     else:
-      supported_engine_names.append('tinygrad')
+      supported_engine_names.append('tinygrad')
     return supported_engine_names
 
-  async def broadcast_supported_engines(self, supported_engines: List):
+  async def broadcast_supported_engines(self, supported_engines_names: List[str]):
     status_message = json.dumps({
-      "type": "supported_inference_engines",
-      "node_id": self.id,
-      "engines": supported_engines
+      "type": "supported_inference_engines",
+      "node_id": self.id,
+      "engines": supported_engines_names
     })
     logger.error(f'broadcast_supported_engines: {status_message}')
     await self.broadcast_opaque_status("", status_message)
     logger.error(f'broadcast_supported_engines: done')
 
-  def get_topology_inference_engines(self) -> List[str]:
+  def get_topology_inference_engines(self) -> List[List[str]]:
+    logger.error(f'topology_inference_engines_pool: {self.topology_inference_engines_pool}')
     return self.topology_inference_engines_pool
 
   async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
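
Taken together, the three methods above form a simple advertise step. A minimal
usage sketch, assuming a constructed StandardNode named node and run from an
async context (values shown are illustrative):

    engines = node.get_supported_inference_engines()  # ["mlx", "tinygrad"] on an MLX-capable host
    await node.broadcast_supported_engines(engines)   # peers append the list to their pools
    node.get_topology_inference_engines()             # e.g. [["mlx", "tinygrad"], ["tinygrad"]]
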
@@ -365,20 +379,15 @@ class StandardNode(Node):
   async def select_best_inference_engine(self):
     supported_engines = self.get_supported_inference_engines()
     await self.broadcast_supported_engines(supported_engines)
-    logger.error('ABOVE and ALL')
-    logger.error("Topology inference engines pool: %s", self.get_topology_inference_engines())
+    logger.error("Topology inference engines pool: %s", len(self.get_topology_inference_engines()))
     logger.error(f'result:{self.get_topology_inference_engines()}')
-    if self.get_topology_inference_engines():
-      logger.info("Topology inference engines pool: %s", self.get_topology_inference_engines())
-      topology_inference_engines_pool = self.get_topology_inference_engines()
-      if any("tinygrad" in engines and len(engines) == 1 for engines in topology_inference_engines_pool):
-        return "tinygrad"
-      common_engine_across_peers = set.intersection(*topology_inference_engines_pool)
-      if "mlx" in common_engine_across_peers:
-        return "mlx"
+    if len(self.get_topology_inference_engines()):
+      if any(len(engines) == 1 and "tinygrad" in engines for engines in self.get_topology_inference_engines()):
+        logger.info("Found node with only tinygrad, using tinygrad on all nodes")
+        self.inference_engine = get_inference_engine("tinygrad", self.shard_downloader)
       else:
-        raise ValueError("No compatible inference engine found across all nodes")
-
+        logger.info("all nodes can use mlx, using mlx for inference")
+        self.inference_engine = get_inference_engine("mlx", self.shard_downloader)
 
   async def periodic_topology_collection(self, interval: int):
     while True:
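
The new selection rule reduces to: if any node reports only tinygrad, every node
falls back to tinygrad; otherwise mlx is used. A standalone sketch of that rule
(the function name is illustrative):

    from typing import List

    def pick_engine(pool: List[List[str]]) -> str:
      # A node reporting exactly ["tinygrad"] forces the lowest common denominator.
      if any(len(engines) == 1 and "tinygrad" in engines for engines in pool):
        return "tinygrad"
      return "mlx"

    assert pick_engine([["mlx", "tinygrad"], ["tinygrad"]]) == "tinygrad"
    assert pick_engine([["mlx", "tinygrad"], ["mlx", "tinygrad"]]) == "mlx"
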
@@ -462,7 +471,6 @@ class StandardNode(Node):
 
     async def send_status_to_peer(peer):
       try:
-        status_dict = json.loads(status)
         await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
       except asyncio.TimeoutError:
         print(f"Timeout sending opaque status to {peer.id()}")