josh 8 months ago
parent
commit
da52357321
3 changed files with 26 additions and 23 deletions
  1. exo/main.py (+1 -19)
  2. exo/networking/udp/udp_discovery.py (+5 -1)
  3. exo/orchestration/standard_node.py (+20 -3)

+ 1 - 19
exo/main.py

@@ -153,24 +153,6 @@ async def shutdown(signal, loop):
   await server.stop()
   loop.stop()
 
-async def select_best_inference_engine(node: StandardNode):
-  supported_engines = node.get_supported_inference_engines()
-  await node.broadcast_supported_engines(supported_engines)
-  logger.error('ABOVE and ALL')
-  logger.error("Topology inference engines pool: %s", node.get_topology_inference_engines())
-  logger.error(f'result:{node.get_topology_inference_engines()}')
-  if node.get_topology_inference_engines():
-    logger.info("Topology inference engines pool: %s", node.get_topology_inference_engines())
-    topology_inference_engines_pool = node.get_topology_inference_engines()
-    if any("tinygrad" in engines and len(engines) == 1 for engines in topology_inference_engines_pool):
-        return "tinygrad"
-    common_engine_across_peers = set.intersection(*topology_inference_engines_pool)
-    if "mlx" in common_engine_across_peers:
-        return "mlx"
-    else:
-        raise ValueError("No compatible inference engine found across all nodes")
-
-
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
   shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
   if not shard:
@@ -210,7 +192,7 @@ async def main():
     loop.add_signal_handler(s, handle_exit)
 
   await node.start(wait_for_peers=args.wait_for_peers)
-  await select_best_inference_engine(node)
+
   if args.command == "run" or args.run_model:
     model_name = args.model_name or args.run_model
     if not model_name:

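Note: the module-level select_best_inference_engine(node) helper removed here is not dropped outright; it moves onto StandardNode (see the standard_node.py hunk below), so main() no longer triggers selection right after node.start(). A minimal sketch of how a caller could still invoke it explicitly, assuming the method added below; the wait_for_peers value and function name are illustrative, not part of this commit:

# Hedged sketch, not code from the repository.
async def start_and_pick_engine(node):
  await node.start(wait_for_peers=0)  # assumption: same call shape as in main()
  # The selection logic now lives on the node itself (added in standard_node.py below).
  return await node.select_best_inference_engine()
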
+ 5 - 1
exo/networking/udp/udp_discovery.py

@@ -159,7 +159,11 @@ class UDPDiscovery(Discovery):
           if peer_id in self.known_peers: del self.known_peers[peer_id]
           return
         if peer_id in self.known_peers: self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
-  
+    if message["type"] == "supported_inference_engines":
+      logger.error(f'supported_inference_engines: {message}')
+      peer_id = message["node_id"]
+      engines = message["engines"]
+      if peer_id in self.known_peers: self.known_peers[peer_id][0].topology_inference_engines_pool.append(engines)
   async def task_listen_for_peers(self):
     await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
                                                             local_addr=("0.0.0.0", self.listen_port))

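For context, the new branch in on_listen_message keys off a "supported_inference_engines" message carrying "node_id" and "engines". Those field names come from the hunk above; the concrete peer id and engine list below are made up for illustration. A minimal standalone sketch of producing and decoding such a message:

import json

# Hedged sketch of the message shape handled above; values are hypothetical.
raw = json.dumps({
  "type": "supported_inference_engines",
  "node_id": "node-a",             # hypothetical peer id
  "engines": ["mlx", "tinygrad"],  # hypothetical engine list
})

message = json.loads(raw)
if message["type"] == "supported_inference_engines":
  peer_id = message["node_id"]
  engines = message["engines"]
  # The real handler appends `engines` to the matching peer's
  # topology_inference_engines_pool when peer_id is in self.known_peers.
  print(peer_id, engines)
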
+ 20 - 3
exo/orchestration/standard_node.py

@@ -362,6 +362,24 @@ class StandardNode(Node):
     self.peers = next_peers
     return len(peers_added) > 0 or len(peers_removed) > 0 or len(peers_updated) > 0
 
+  async def select_best_inference_engine(self):
+    supported_engines = self.get_supported_inference_engines()
+    await self.broadcast_supported_engines(supported_engines)
+    logger.error('ABOVE and ALL')
+    logger.error("Topology inference engines pool: %s", self.get_topology_inference_engines())
+    logger.error(f'result:{self.get_topology_inference_engines()}')
+    if self.get_topology_inference_engines():
+      logger.info("Topology inference engines pool: %s", self.get_topology_inference_engines())
+      topology_inference_engines_pool = self.get_topology_inference_engines()
+      if any("tinygrad" in engines and len(engines) == 1 for engines in topology_inference_engines_pool):
+          return "tinygrad"
+      common_engine_across_peers = set.intersection(*topology_inference_engines_pool)
+      if "mlx" in common_engine_across_peers:
+          return "mlx"
+      else:
+          raise ValueError("No compatible inference engine found across all nodes")
+
+
   async def periodic_topology_collection(self, interval: int):
     while True:
       await asyncio.sleep(interval)
@@ -369,7 +387,9 @@ class StandardNode(Node):
         did_peers_change = await self.update_peers()
         if DEBUG >= 2: print(f"{did_peers_change=}")
         if did_peers_change:
+          logger.error('peers changed, collecting topology and selecting best inference engine')
           await self.collect_topology()
+          await self.select_best_inference_engine()
       except Exception as e:
         print(f"Error collecting topology: {e}")
         traceback.print_exc()
@@ -443,15 +463,12 @@ class StandardNode(Node):
     async def send_status_to_peer(peer):
       try:
         status_dict = json.loads(status)
-        if status_dict.get("type") == "supported_inference_engines":
-          logger.error(f'broadcasting_inference_engines: {status_dict}')
         await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
       except asyncio.TimeoutError:
         print(f"Timeout sending opaque status to {peer.id()}")
       except Exception as e:
         print(f"Error sending opaque status to {peer.id()}: {e}")
         traceback.print_exc()
-
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
     # in the case of opaque status, we also want to receive our own opaque statuses
     self.on_opaque_status.trigger_all(request_id, status)
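
For reference, the selection rule that the new select_best_inference_engine method applies to the topology pool can be exercised in isolation: prefer tinygrad if any peer supports only tinygrad, otherwise use mlx if it is common to every peer, otherwise fail. This is a hedged standalone sketch with made-up peer pools, not code from the repository; it assumes the pool entries can be treated as sets of engine class names.

# Minimal sketch of the selection rule added above (illustrative only).
def pick_engine(topology_inference_engines_pool):
  # If any peer supports only tinygrad, the whole topology falls back to tinygrad.
  if any("tinygrad" in engines and len(engines) == 1 for engines in topology_inference_engines_pool):
    return "tinygrad"
  # Otherwise pick mlx only if every peer supports it.
  common = set.intersection(*map(set, topology_inference_engines_pool))
  if "mlx" in common:
    return "mlx"
  raise ValueError("No compatible inference engine found across all nodes")

print(pick_engine([{"mlx", "tinygrad"}, {"mlx"}]))       # -> "mlx"
print(pick_engine([{"mlx", "tinygrad"}, {"tinygrad"}]))  # -> "tinygrad"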