@@ -362,6 +362,24 @@ class StandardNode(Node):
     self.peers = next_peers
     return len(peers_added) > 0 or len(peers_removed) > 0 or len(peers_updated) > 0
 
+  async def select_best_inference_engine(self):
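+    # Share this node's supported inference engines with peers, then pick one the whole topology can run.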
+    supported_engines = self.get_supported_inference_engines()
+    await self.broadcast_supported_engines(supported_engines)
+    topology_inference_engines_pool = self.get_topology_inference_engines()
+    if DEBUG >= 2: print(f"Topology inference engines pool: {topology_inference_engines_pool}")
+    if topology_inference_engines_pool:
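+      # If any node supports only tinygrad, fall back to tinygrad for the whole cluster.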
+      if any("tinygrad" in engines and len(engines) == 1 for engines in topology_inference_engines_pool):
+        return "tinygrad"
+      common_engine_across_peers = set.intersection(*map(set, topology_inference_engines_pool))
+      if "mlx" in common_engine_across_peers:
+        return "mlx"
+      else:
+        raise ValueError("No compatible inference engine found across all nodes")
+
+
   async def periodic_topology_collection(self, interval: int):
     while True:
       await asyncio.sleep(interval)
@@ -369,7 +387,9 @@ class StandardNode(Node):
         did_peers_change = await self.update_peers()
         if DEBUG >= 2: print(f"{did_peers_change=}")
         if did_peers_change:
+          if DEBUG >= 2: print("Peers changed; collecting topology and selecting best inference engine")
           await self.collect_topology()
+          await self.select_best_inference_engine()
       except Exception as e:
         print(f"Error collecting topology: {e}")
         traceback.print_exc()
@@ -443,15 +463,12 @@ class StandardNode(Node):
     async def send_status_to_peer(peer):
       try:
         status_dict = json.loads(status)
-        if status_dict.get("type") == "supported_inference_engines":
-          logger.error(f'broadcasting_inference_engines: {status_dict}')
         await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
       except asyncio.TimeoutError:
         print(f"Timeout sending opaque status to {peer.id()}")
       except Exception as e:
         print(f"Error sending opaque status to {peer.id()}: {e}")
         traceback.print_exc()
-
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
     # in the case of opaque status, we also want to receive our own opaque statuses
     self.on_opaque_status.trigger_all(request_id, status)