
changes to broadcast func

josh · 8 months ago · commit 3908b97a2b
2 changed files with 53 additions and 0 deletions
  1. exo/main.py (+37 -0)
  2. exo/orchestration/standard_node.py (+16 -0)

exo/main.py (+37 -0)

@@ -144,6 +144,43 @@ async def shutdown(signal, loop):
  await server.stop()
  loop.stop()
 
 
+async def select_best_inference_engine(node: StandardNode):
+  supported_engines = set(node.get_supported_inference_engines())
+  await node.broadcast_supported_engines(supported_engines)
+  num_peers = len(node.peers)
+  all_peers_responded = asyncio.Event()
+
+  def check_all_responses():
+    if len(node.received_opaque_statuses) >= num_peers:
+      all_peers_responded.set()
+
+  node.on_opaque_status.register("engine_selection").on_next(lambda *args: check_all_responses())
+  try:
+    await asyncio.wait_for(all_peers_responded.wait(), timeout=10.0)
+  except asyncio.TimeoutError:
+    print("Timed out waiting for peer nodes to respond.")
+  node.on_opaque_status.unregister("engine_selection")
+  all_supported_engines = [supported_engines]
+  for peer_id, status in node.received_opaque_statuses:
+    try:
+      status_data = json.loads(status)
+      if status_data.get("type") == "supported_inference_engines":
+        all_supported_engines.append(set(status_data.get("engines", [])))
+    except json.JSONDecodeError:
+      continue
+  # If any node supports only tinygrad, the whole cluster must settle on tinygrad.
+  if any("tinygrad" in engines and len(engines) == 1 for engines in all_supported_engines):
+    return "tinygrad"
+  common_engines_across_peers = set.intersection(*all_supported_engines)
+  # Debug: record the engines common to all nodes (write() needs a str; the with-block closes the file).
+  with open('check_engines.txt', 'w') as f:
+    f.write(str(common_engines_across_peers))
+  print(f'common_engines_across_peers: {common_engines_across_peers}')
+  if "mlx" in common_engines_across_peers:
+    return "mlx"
+  elif "tinygrad" in common_engines_across_peers:
+    return "tinygrad"
+  else:
+    raise ValueError("No compatible inference engine found across all nodes")
+
 
 
  async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
    shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
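
For context, a minimal sketch of how the new helper might be wired into startup. The surrounding flow and the engine factory shown here are assumptions for illustration, not part of this commit:

  # Hypothetical startup flow: negotiate a common engine once peers are connected.
  async def main():
    await node.start(wait_for_peers=args.wait_for_peers)    # peers must be discovered first
    engine_name = await select_best_inference_engine(node)  # returns "mlx" or "tinygrad"
    inference_engine = get_inference_engine(engine_name)    # assumed factory, not in this diff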

exo/orchestration/standard_node.py (+16 -0)

@@ -76,6 +76,22 @@ class StandardNode(Node):
      if DEBUG >= 1: print(f"Error updating visualization: {e}")
      if DEBUG >= 1: traceback.print_exc()
 
 
+  def get_supported_inference_engines(self):
+    supported_engines = []
+    # An MLX-capable node can also run tinygrad; everything else is tinygrad-only.
+    if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
+      supported_engines.extend(['mlx', 'tinygrad'])  # extend takes a single iterable
+    else:
+      supported_engines.append('tinygrad')
+    return supported_engines
+
+  async def broadcast_supported_engines(self, supported_engines):
+    await self.broadcast_opaque_status("", json.dumps({
+      "type": "supported_inference_engines",
+      "node_id": self.id,
+      "engines": list(supported_engines)  # sets aren't JSON-serializable
+    }))
+
  async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
    shard = self.get_current_shard(base_shard)
    asyncio.create_task(
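
For illustration, a self-contained sketch of the selection rule this diff implements: each node broadcasts its engine set as an opaque status, the sets are intersected, and mlx is preferred unless some node is tinygrad-only. The sample payloads below are hypothetical:

  import json

  # Payloads as two peers might broadcast them via broadcast_opaque_status.
  statuses = [
    json.dumps({"type": "supported_inference_engines", "node_id": "a", "engines": ["mlx", "tinygrad"]}),
    json.dumps({"type": "supported_inference_engines", "node_id": "b", "engines": ["tinygrad"]}),
  ]

  engine_sets = [set(json.loads(s)["engines"]) for s in statuses]
  # Node "b" is tinygrad-only, so the intersection settles on tinygrad.
  assert set.intersection(*engine_sets) == {"tinygrad"}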