@@ -2,6 +2,8 @@ import argparse
 import asyncio
 import signal
 import json
+import logging
+import seqlog
 import time
 import traceback
 import uuid
@@ -51,6 +53,13 @@ args = parser.parse_args()
 
 print_yellow_exo()
 
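+# Route stdlib logging records to a local Seq server (5341 is Seq's default ingestion port).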
+seqlog.log_to_seq(
+  server_url='http://localhost:5341',
+  level=logging.DEBUG
+)
+logger = logging.getLogger('exo_logger')
+logger.setLevel(logging.DEBUG)
+
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
@@ -145,37 +154,18 @@ async def shutdown(signal, loop):
   loop.stop()
 
 async def select_best_inference_engine(node: StandardNode):
-  supported_engines = set(node.get_supported_inference_engines())
+  supported_engines = node.get_supported_inference_engines()
   await node.broadcast_supported_engines(supported_engines)
-  num_peers = len(node.peers)
-  all_peers_responded = asyncio.Event()
-  def check_all_responses():
-    if len(node.received_opaque_statuses) >= num_peers:
-      all_peers_responded.set()
-  node.on_opaque_status.register("engine_selection").on_next(lambda *args: check_all_responses())
-  try:
-    await asyncio.wait_for(all_peers_responded.wait(), timeout=10.0)
-  except asyncio.TimeoutError:
-    print("Timed out waiting for peer nodes to respond.")
-  node.on_opaque_status.unregister("engine_selection")
-  all_supported_engines = [supported_engines]
-  for peer_id, status in node.received_opaque_statuses:
-    try:
-      status_data = json.loads(status)
-      if status_data.get("type") == "supported_inference_engines":
-        all_supported_engines.append(set(status_data.get("engines", [])))
-    except json.JSONDecodeError:
-      continue
-  if any("tinygrad" in engines and len(engines) == 1 for engines in all_supported_engines):
-    return "tinygrad"
-  common_engine_across_peers = set.intersection(*all_supported_engines)
-  if "mlx" in common_engine_across_peers:
-    print('mlx')
-    return "mlx"
-  elif "tinygrad" in common_engine_across_peers:
-    return "tinygrad"
-  else:
-    raise ValueError("No compatible inference engine found across all nodes")
+
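+  # Choose an engine every node in the topology can run: a tinygrad-only node
+  # forces tinygrad; otherwise prefer mlx when it is common to all nodes.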
+  if node.get_topology_inference_engines():
+    topology_inference_engines_pool = node.get_topology_inference_engines()
+    if any("tinygrad" in engines and len(engines) == 1 for engines in topology_inference_engines_pool):
+      return "tinygrad"
+    common_engine_across_peers = set.intersection(*topology_inference_engines_pool)
+    if "mlx" in common_engine_across_peers:
+      return "mlx"
+    else:
+      raise ValueError("No compatible inference engine found across all nodes")
 
 
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
@@ -217,9 +207,8 @@ async def main():
     loop.add_signal_handler(s, handle_exit)
 
   await node.start(wait_for_peers=args.wait_for_peers)
-  if len(node.peers) > 1:
-    compatible_engine = await select_best_inference_engine(node)
-    node.inference_engine = get_inference_engine(compatible_engine, shard_downloader)
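+  # Broadcasts this node's engines, then picks one supported across the whole topology.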
+  await select_best_inference_engine(node)
+
 
   if args.command == "run" or args.run_model:
     model_name = args.model_name or args.run_model