Browse Source

changes to inference engine

josh 8 months ago
parent
commit
6dd2f7ab89
2 changed files with 7 additions and 9 deletions
  1. +6 -7
      exo/main.py
  2. +1 -2
      exo/orchestration/standard_node.py

+ 6 - 7
exo/main.py

@@ -168,15 +168,11 @@ async def select_best_inference_engine(node: StandardNode):
           continue
   if any("tinygrad" in engines and len(engines) == 1 for engines in all_supported_engines):
       return "tinygrad"
-  common_engines_across_peers = set.intersection(*all_supported_engines)
-  with open('check_engines.txt', 'w') as f:
-    f.write(common_engines_across_peers)
-    f.close()
-  print(f'common_engines_across_peers:{common_engines_across_peers}')
-  if "mlx" in common_engines_across_peers:
+  common_engine_across_peers = set.intersection(*all_supported_engines)
+  if "mlx" in common_engine_across_peers:
       print('mlx')
       return "mlx"
-  elif "tinygrad" in common_engines_across_peers:
+  elif "tinygrad" in common_engine_across_peers:
       return "tinygrad"
   else:
       raise ValueError("No compatible inference engine found across all nodes")
@@ -221,6 +217,9 @@ async def main():
     loop.add_signal_handler(s, handle_exit)
 
   await node.start(wait_for_peers=args.wait_for_peers)
+  if len(node.peers) > 1:
+    compatible_engine = await select_best_inference_engine(node)
+    node.inference_engine = get_inference_engine(compatible_engine, shard_downloader)
 
   if args.command == "run" or args.run_model:
     model_name = args.model_name or args.run_model

+ 1 - 2
exo/orchestration/standard_node.py

@@ -84,8 +84,7 @@ class StandardNode(Node):
       supported_engines.append('tinygrad')
     return supported_engines
 
-  async def broadcast_supported_engines(self):
-    supported_engines = self.get_supported_inference_engines()
+  async def broadcast_supported_engines(self, supported_engines: List):
     await self.broadcast_opaque_status("", json.dumps({
       "type": "supported_inference_engines",
       "node_id": self.id,