|
@@ -144,6 +144,43 @@ async def shutdown(signal, loop):
|
|
|
await server.stop()
|
|
|
loop.stop()
|
|
|
|
|
|
async def select_best_inference_engine(node: StandardNode, response_timeout: float = 10.0):
    """Pick an inference engine that this node and its peers can all run.

    The node broadcasts its own supported engines, waits until as many
    opaque statuses as there are peers have been recorded (or until
    ``response_timeout`` elapses), then combines every engine report found
    in ``node.received_opaque_statuses``.

    Selection rules, in order:
      1. If any participant supports *only* ``"tinygrad"``, use ``"tinygrad"``.
      2. Otherwise, if ``"mlx"`` is common to every participant, use ``"mlx"``.
      3. Otherwise, if ``"tinygrad"`` is common to every participant, use it.

    Args:
        node: Local node; must provide ``get_supported_inference_engines()``,
            ``broadcast_supported_engines()``, ``peers``,
            ``received_opaque_statuses`` and ``on_opaque_status``.
        response_timeout: Seconds to wait for peer responses before
            continuing with whatever has been received so far.

    Returns:
        ``"mlx"`` or ``"tinygrad"``.

    Raises:
        ValueError: If no engine is shared by every participant.
    """
    supported_engines = set(node.get_supported_inference_engines())
    await node.broadcast_supported_engines(supported_engines)

    num_peers = len(node.peers)
    all_peers_responded = asyncio.Event()

    def check_all_responses():
        # NOTE(review): this counts every recorded opaque status, not only
        # engine reports -- confirm that is the intended completion signal.
        if len(node.received_opaque_statuses) >= num_peers:
            all_peers_responded.set()

    node.on_opaque_status.register("engine_selection").on_next(lambda *args: check_all_responses())
    try:
        # Evaluate once up front: with zero peers, or when every reply already
        # arrived while the broadcast was awaited, no further callback would
        # fire and the wait below would always run into the timeout.
        check_all_responses()
        await asyncio.wait_for(all_peers_responded.wait(), timeout=response_timeout)
    except asyncio.TimeoutError:
        print("Timed out waiting for peer nodes to respond.")
    finally:
        # Always drop the callback, including on cancellation.
        # NOTE(review): confirm the callback system exposes `unregister`
        # under exactly this name.
        node.on_opaque_status.unregister("engine_selection")

    statuses = node.received_opaque_statuses
    # Presumably a sequence of (peer_id, status) pairs; a mapping of
    # peer_id -> status is accepted as well -- verify against the node class.
    entries = statuses.items() if isinstance(statuses, dict) else statuses

    all_supported_engines = [supported_engines]
    for _peer_id, status in entries:
        peer_engines = _parse_supported_engines(status)
        if peer_engines is not None:
            all_supported_engines.append(peer_engines)

    # A participant that can only run tinygrad forces tinygrad for everyone.
    if any(engines == {"tinygrad"} for engines in all_supported_engines):
        return "tinygrad"

    common_engines_across_peers = set.intersection(*all_supported_engines)
    print(f'common_engines_across_peers:{common_engines_across_peers}')

    if "mlx" in common_engines_across_peers:
        print('mlx')
        return "mlx"
    if "tinygrad" in common_engines_across_peers:
        return "tinygrad"
    raise ValueError("No compatible inference engine found across all nodes")


def _parse_supported_engines(status):
    """Return the engine set carried by one opaque status, or None if it is not an engine report."""
    try:
        status_data = json.loads(status)
    except (TypeError, ValueError):
        # Not a string/bytes payload, or not valid JSON: ignore this status.
        return None
    if not isinstance(status_data, dict):
        return None
    if status_data.get("type") != "supported_inference_engines":
        return None
    try:
        # A missing or null "engines" field means the peer advertised nothing.
        return set(status_data.get("engines") or [])
    except TypeError:
        # Non-iterable or unhashable payload: treat as malformed and skip.
        return None
|
|
|
+
|
|
|
|
|
|
async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
|
|
|
shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
|