|
@@ -168,15 +168,11 @@ async def select_best_inference_engine(node: StandardNode):
|
|
|
continue
|
|
|
if any("tinygrad" in engines and len(engines) == 1 for engines in all_supported_engines):
|
|
|
return "tinygrad"
|
|
|
- common_engines_across_peers = set.intersection(*all_supported_engines)
|
|
|
- with open('check_engines.txt', 'w') as f:
|
|
|
- f.write(common_engines_across_peers)
|
|
|
- f.close()
|
|
|
- print(f'common_engines_across_peers:{common_engines_across_peers}')
|
|
|
- if "mlx" in common_engines_across_peers:
|
|
|
+ common_engine_across_peers = set.intersection(*all_supported_engines)
|
|
|
+ if "mlx" in common_engine_across_peers:
|
|
|
print('mlx')
|
|
|
return "mlx"
|
|
|
- elif "tinygrad" in common_engines_across_peers:
|
|
|
+ elif "tinygrad" in common_engine_across_peers:
|
|
|
return "tinygrad"
|
|
|
else:
|
|
|
raise ValueError("No compatible inference engine found across all nodes")
|
|
@@ -221,6 +217,9 @@ async def main():
|
|
|
loop.add_signal_handler(s, handle_exit)
|
|
|
|
|
|
await node.start(wait_for_peers=args.wait_for_peers)
|
|
|
+ if len(node.peers) > 1:
|
|
|
+ compatible_engine = await select_best_inference_engine(node)
|
|
|
+ node.inference_engine = get_inference_engine(compatible_engine, shard_downloader)
|
|
|
|
|
|
if args.command == "run" or args.run_model:
|
|
|
model_name = args.model_name or args.run_model
|