Bläddra i källkod

Fix an issue where an offline node cannot detect an online node over Thunderbolt because the online node does not broadcast over the Thunderbolt bridge when non-bridge alternatives exist (i.e. Wi-Fi or Ethernet)
Extract the broadcast logic into a DatagramProtocol so that it follows the same pattern as the listener

Mark Van Aken 7 månader sedan
förälder
incheckning
27bf5069ab
1 ändrade filer med 52 tillägg och 28 borttagningar
  1. 52 28
      exo/networking/udp_discovery.py

+ 52 - 28
exo/networking/udp_discovery.py

@@ -7,7 +7,7 @@ from typing import List, Dict, Callable, Tuple, Coroutine
 from .discovery import Discovery
 from .peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
-from exo.helpers import DEBUG, DEBUG_DISCOVERY
+from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
 
 class ListenProtocol(asyncio.DatagramProtocol):
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
@@ -22,6 +22,17 @@ class ListenProtocol(asyncio.DatagramProtocol):
     asyncio.create_task(self.on_message(data, addr))
 
 
+class BroadcastProtocol(asyncio.DatagramProtocol):
+  def __init__(self, message: str, broadcast_port: int):
+    self.message = message
+    self.broadcast_port = broadcast_port
+
+  def connection_made(self, transport):
+    sock = transport.get_extra_info("socket")
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+    transport.sendto(self.message.encode("utf-8"), ("<broadcast>", self.broadcast_port))
+
+
 class UDPDiscovery(Discovery):
   def __init__(
     self,
@@ -68,31 +79,40 @@ class UDPDiscovery(Discovery):
     return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
 
   async def task_broadcast_presence(self):
+    message = json.dumps({
+      "type": "discovery",
+      "node_id": self.node_id,
+      "grpc_port": self.node_port,
+      "device_capabilities": self.device_capabilities.to_dict(),
+    })
+
+    if DEBUG_DISCOVERY >= 2:
+      print("Starting task_broadcast_presence...")
+      print(f"\nBroadcast message: {message}")
+
     while True:
-      try:
-        message = json.dumps({
-          "type": "discovery",
-          "node_id": self.node_id,
-          "grpc_port": self.node_port,
-          "device_capabilities": self.device_capabilities.to_dict(),
-        }).encode("utf-8")
-        if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
-
-        transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET)
-        sock = transport.get_extra_info("socket")
-        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
-        transport.sendto(message, ("<broadcast>", self.broadcast_port))
-      except Exception as e:
-        print(f"Error in broadcast presence: {e}")
-        print(traceback.format_exc())
-      finally:
-        if transport:
-          try:
-            transport.close()
-          except:
-            if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
-            if DEBUG_DISCOVERY >= 2: traceback.print_exc()
-        await asyncio.sleep(self.broadcast_interval)
+      # Explicitly broadcast on every assigned IP address, since broadcasting on `0.0.0.0` on macOS does not go out over
+      # the Thunderbolt bridge when other connection modalities such as Wi-Fi or Ethernet exist
+      for addr in get_all_ip_addresses():
+        transport = None
+        try:
+          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
+            lambda: BroadcastProtocol(message, self.broadcast_port),
+            local_addr=(addr, 0),
+            family=socket.AF_INET
+          )
+          if DEBUG_DISCOVERY >= 3:
+            print(f"Broadcasting presence at ({addr})")
+        except Exception as e:
+          print(f"Error in broadcast presence ({addr}): {e}")
+        finally:
+          if transport:
+            try:
+              transport.close()
+            except Exception as e:
+              if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
+              if DEBUG_DISCOVERY >= 2: traceback.print_exc()
+      await asyncio.sleep(self.broadcast_interval)
 
   async def on_listen_message(self, data, addr):
     if not data:
@@ -120,7 +140,8 @@ class UDPDiscovery(Discovery):
       peer_port = message["grpc_port"]
       device_capabilities = DeviceCapabilities(**message["device_capabilities"])
       if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
-        if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
+        if DEBUG >= 1: print(
+          f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
         self.known_peers[peer_id] = (
           self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
           time.time(),
@@ -129,7 +150,8 @@ class UDPDiscovery(Discovery):
       self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
 
   async def task_listen_for_peers(self):
-    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port))
+    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
+                                                            local_addr=("0.0.0.0", self.listen_port))
     if DEBUG_DISCOVERY >= 2: print("Started listen task")
 
   async def task_cleanup_peers(self):
@@ -140,7 +162,9 @@ class UDPDiscovery(Discovery):
           peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
           if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
         ]
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
+        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {
+          peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for
+          peer_handle, connected_at, last_seen in self.known_peers.values()})
         for peer_id in peers_to_remove:
           if peer_id in self.known_peers: del self.known_peers[peer_id]
           if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")