浏览代码

more robust udp broadcast

Alex Cheema 8 月之前
父节点
当前提交
198308b1eb
共有 1 个文件被更改,包括 22 次插入5 次删除
  1. 22 5
      exo/networking/udp/udp_discovery.py

+ 22 - 5
exo/networking/udp/udp_discovery.py

@@ -23,15 +23,29 @@ class ListenProtocol(asyncio.DatagramProtocol):
     asyncio.create_task(self.on_message(data, addr))
 
 
+def get_broadcast_address(ip_addr: str) -> str:
+  try:
+    # Split IP into octets and create broadcast address for the subnet
+    ip_parts = ip_addr.split('.')
+    return f"{ip_parts[0]}.{ip_parts[1]}.{ip_parts[2]}.255"
+  except:
+    return "255.255.255.255"
+
+
 class BroadcastProtocol(asyncio.DatagramProtocol):
-  def __init__(self, message: str, broadcast_port: int):
+  def __init__(self, message: str, broadcast_port: int, source_ip: str):
     self.message = message
     self.broadcast_port = broadcast_port
+    self.source_ip = source_ip
 
   def connection_made(self, transport):
     sock = transport.get_extra_info("socket")
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
-    transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port))
+    # Try both subnet-specific and global broadcast
+    broadcast_addr = get_broadcast_address(self.source_ip)
+    transport.sendto(self.message.encode("utf-8"), (broadcast_addr, self.broadcast_port))
+    if broadcast_addr != "255.255.255.255":
+      transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port))
 
 
 class UDPDiscovery(Discovery):
@@ -99,14 +113,17 @@ class UDPDiscovery(Discovery):
 
         transport = None
         try:
-          # Create socket with explicit broadcast permission
           sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
           sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+          sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+          try:
+            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+          except AttributeError:
+            pass
           sock.bind((addr, 0))
           
-          # Create transport with the pre-configured socket
           transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
-            lambda: BroadcastProtocol(message, self.broadcast_port),
+            lambda: BroadcastProtocol(message, self.broadcast_port, addr),
             sock=sock
           )
         except Exception as e: