Browse Source

more robust health checks

Alex Cheema 9 months ago
parent
commit
db3c603d73
2 changed files with 11 additions and 16 deletions
  1. exo/networking/grpc/grpc_peer_handle.py (+0 -1)
  2. exo/networking/udp_discovery.py (+11 -15)

+ 0 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -3,7 +3,6 @@ import numpy as np
 import asyncio
 from typing import Optional, Tuple, List
 
-# These would be generated from the .proto file
 from . import node_service_pb2
 from . import node_service_pb2_grpc
 

+ 11 - 15
exo/networking/udp_discovery.py

@@ -53,7 +53,7 @@ class UDPDiscovery(Discovery):
     self.broadcast_interval = broadcast_interval
     self.discovery_timeout = discovery_timeout
     self.device_capabilities = device_capabilities
-    self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
+    self.known_peers: Dict[str, Tuple[PeerHandle, float, float, bool]] = {}
     self.broadcast_task = None
     self.listen_task = None
     self.cleanup_task = None
@@ -76,7 +76,7 @@ class UDPDiscovery(Discovery):
       while len(self.known_peers) < wait_for_peers:
         if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
         await asyncio.sleep(0.1)
-    return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
+    return [peer_handle for peer_handle, _, _, last_is_healthy in self.known_peers.values() if last_is_healthy]
 
   async def task_broadcast_presence(self):
     message = json.dumps({
@@ -144,16 +144,13 @@ class UDPDiscovery(Discovery):
       new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
 
       # Check if the new peer is healthy before adding
-      if await new_peer_handle.health_check():
-        if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
-          if DEBUG >= 1: print(
-            f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
-          self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time())
-        else:
-          # Update last seen time for existing peer
-          self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
+      if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
+        if DEBUG >= 1: print(
+          f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
+        self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time(), await new_peer_handle.health_check())
       else:
-        if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} failed health check. Not adding.")
+        # Update last seen time for existing peer
+        self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), await new_peer_handle.health_check())
 
   async def task_listen_for_peers(self):
     await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
@@ -165,10 +162,9 @@ class UDPDiscovery(Discovery):
       try:
         current_time = time.time()
         peers_to_remove = []
-        for peer_id, (peer_handle, connected_at, last_seen) in self.known_peers.items():
-          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or \
-             current_time - last_seen > self.discovery_timeout or \
-             not await peer_handle.health_check():
+        for peer_id, (peer_handle, connected_at, last_seen, last_is_healthy) in self.known_peers.items():
+          if ((not await peer_handle.is_connected() or not last_is_healthy or not await peer_handle.health_check()) and current_time - connected_at > self.discovery_timeout) or \
+             current_time - last_seen > self.discovery_timeout:
             peers_to_remove.append(peer_id)
 
         if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})