瀏覽代碼

fix the race condition in cleanup peers and run the peer checks concurrently. fixes #308

Alex Cheema 10 月之前
父節點
當前提交
e80ee60760
共有 2 個文件被更改,包括 60 次插入和 16 次刪除
  1. 32 8
      exo/networking/tailscale/tailscale_discovery.py
  2. 28 8
      exo/networking/udp/udp_discovery.py

+ 32 - 8
exo/networking/tailscale/tailscale_discovery.py

@@ -6,7 +6,7 @@ from exo.networking.discovery import Discovery
 from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.helpers import DEBUG, DEBUG_DISCOVERY
-from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, update_device_attributes, get_tailscale_devices, Device
+from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices, Device
 
 class TailscaleDiscovery(Discovery):
   def __init__(
@@ -133,16 +133,40 @@ class TailscaleDiscovery(Discovery):
     while True:
       try:
         current_time = time.time()
-        peers_to_remove = [
-          peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
-          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout or not await peer_handle.health_check()
-        ]
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}, health_check={await peer_handle.health_check()}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
+        peers_to_remove = []
+
+        peer_ids = list(self.known_peers.keys())
+        results = await asyncio.gather(*[self.check_peer(peer_id, current_time) for peer_id in peer_ids], return_exceptions=True)
+
+        for peer_id, should_remove in zip(peer_ids, results):
+          if should_remove: peers_to_remove.append(peer_id)
+
+        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", { peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}" for peer_handle, connected_at, last_seen in self.known_peers.values() })
+
         for peer_id in peers_to_remove:
-          if peer_id in self.known_peers: del self.known_peers[peer_id]
-          if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
+          if peer_id in self.known_peers:
+            del self.known_peers[peer_id]
+            if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity or failed health check.")
       except Exception as e:
         print(f"Error in cleanup peers: {e}")
         print(traceback.format_exc())
       finally:
         await asyncio.sleep(self.discovery_interval)
+
+  async def check_peer(self, peer_id: str, current_time: float) -> bool:
+    """Decide whether the peer stored under `peer_id` in `known_peers` should be removed."""
+    handle, connected_at, last_seen = self.known_peers.get(peer_id, (None, None, None))
+    # An id with no entry in known_peers is never flagged for removal.
+    if handle is None: return False
+
+    try:
+      connected = await handle.is_connected()
+      healthy = await handle.health_check()
+    except Exception as e:
+      # Any error raised by the two checks counts as grounds for removal.
+      if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
+      return True
+
+    if not connected and current_time - connected_at > self.discovery_timeout: return True
+    if current_time - last_seen > self.discovery_timeout: return True
+    return not healthy

+ 28 - 8
exo/networking/udp/udp_discovery.py

@@ -171,19 +171,39 @@ class UDPDiscovery(Discovery):
       try:
         current_time = time.time()
         peers_to_remove = []
-        for peer_id, (peer_handle, connected_at, last_seen, prio) in self.known_peers.items():
-          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or \
-             (current_time - last_seen > self.discovery_timeout) or \
-             (not await peer_handle.health_check()):
-            peers_to_remove.append(peer_id)
 
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, {connected_at=}, {last_seen=}, {prio=}" for peer_handle, connected_at, last_seen, prio in self.known_peers.values()})
+        peer_ids = list(self.known_peers.keys())
+        results = await asyncio.gather(*[self.check_peer(peer_id, current_time) for peer_id in peer_ids], return_exceptions=True)
+
+        for peer_id, should_remove in zip(peer_ids, results):
+          if should_remove: peers_to_remove.append(peer_id)
+
+        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", { peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}, prio={prio}" for peer_handle, connected_at, last_seen, prio in self.known_peers.values() })
 
         for peer_id in peers_to_remove:
-          if peer_id in self.known_peers: del self.known_peers[peer_id]
-          if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity or failed health check.")
+          if peer_id in self.known_peers:
+            del self.known_peers[peer_id]
+            if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity or failed health check.")
       except Exception as e:
         print(f"Error in cleanup peers: {e}")
         print(traceback.format_exc())
       finally:
         await asyncio.sleep(self.broadcast_interval)
+
+  async def check_peer(self, peer_id: str, current_time: float) -> bool:
+    """Decide whether the peer stored under `peer_id` in `known_peers` should be removed."""
+    handle, connected_at, last_seen, _prio = self.known_peers.get(peer_id, (None, None, None, None))
+    # An id with no entry in known_peers is never flagged for removal.
+    if handle is None: return False
+
+    try:
+      connected = await handle.is_connected()
+      healthy = await handle.health_check()
+    except Exception as e:
+      # Any error raised by the two checks counts as grounds for removal.
+      if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
+      return True
+
+    if not connected and current_time - connected_at > self.discovery_timeout: return True
+    if current_time - last_seen > self.discovery_timeout: return True
+    return not healthy