@@ -19,8 +19,6 @@ from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.download.hf.hf_shard_download import HFShardDownloader
 import logging
 
-logger = logging.getLogger(__name__)
-
 class StandardNode(Node):
   def __init__(
     self,
@@ -67,13 +65,9 @@ class StandardNode(Node):
     try:
       status_data = json.loads(opaque_status)
       if status_data.get("type", "") == "supported_inference_engines":
-        logger.error(f'supported_inference_engines: {status_data}')
-        logger.error('inside on_status data')
         node_id = status_data.get("node_id")
         engines = status_data.get("engines", [])
-        logger.error(f'engines: {engines}')
         self.topology_inference_engines_pool.append(engines)
-        logger.error(f'topology_inference_engines_pool: {self.topology_inference_engines_pool}')
       if status_data.get("type", "") == "node_status":
         if status_data.get("status", "").startswith("start_"):
           self.current_topology.active_node_id = status_data.get("node_id")
@@ -106,12 +100,9 @@ class StandardNode(Node):
       "node_id": self.id,
       "engines": supported_engines_names
     })
-    logger.error(f'broadcast_supported_engines: {status_message}')
     await self.broadcast_opaque_status("", status_message)
-    logger.error(f'broadcast_supported_engines: done')
 
   def get_topology_inference_engines(self) -> List[List[str]]:
-    logger.error(f'topology_inference_engines_pool: {self.topology_inference_engines_pool}')
     return self.topology_inference_engines_pool
 
   async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
@@ -379,8 +370,6 @@ class StandardNode(Node):
   async def select_best_inference_engine(self):
     supported_engines = self.get_supported_inference_engines()
     await self.broadcast_supported_engines(supported_engines)
-    logger.error("Topology inference engines pool: %s", len(self.get_topology_inference_engines()))
-    logger.error(f'result:{self.get_topology_inference_engines()}')
     if len(self.get_topology_inference_engines()):
      if any(len(engines) == 1 and "tinygrad" in engines for engines in self.get_topology_inference_engines()):
        logger.info("Found node with only tinygrad, using tinygrad on all nodes")
@@ -396,7 +385,6 @@ class StandardNode(Node):
       did_peers_change = await self.update_peers()
       if DEBUG >= 2: print(f"{did_peers_change=}")
       if did_peers_change:
-        logger.error('peers changed, collecting topology and selecting best inference engine')
         await self.collect_topology()
         await self.select_best_inference_engine()
     except Exception as e:
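
Note on the convention this diff leaves in place: the kept context line `if DEBUG >= 2: print(f"{did_peers_change=}")` shows how this codebase gates transient debug output behind the DEBUG environment variable, which is why the ad-hoc `logger.error(...)` tracing is deleted outright rather than downgraded to `logger.debug`. A minimal sketch of that convention, assuming the flag lives in `exo.helpers` as an int parsed from the DEBUG env var (the import path and the sample value are illustrative, not part of this diff):

    # Sketch only: DEBUG is assumed to be defined in exo.helpers as
    # int(os.getenv("DEBUG", "0")); adjust the import to the real location.
    from exo.helpers import DEBUG

    engines = ["mlx", "tinygrad"]  # hypothetical value for illustration
    # Verbosity-gated print instead of misusing logger.error for tracing:
    if DEBUG >= 2: print(f"supported engines: {engines}")

Run with `DEBUG=2` in the environment to see the output; at the default `DEBUG=0` the line stays silent.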