Browse Source

fix UDPDiscovery params, create a new transport every time we broadcast

Alex Cheema 11 months ago
parent
commit
3dd81a1e05
1 changed file with 21 additions and 22 deletions
  1. 21 22
      exo/networking/udp_discovery.py

+ 21 - 22
exo/networking/udp_discovery.py

@@ -3,13 +3,12 @@ import json
 import socket
 import time
 import traceback
-from typing import List, Dict, Callable, Tuple, Coroutine, Type
+from typing import List, Dict, Callable, Tuple, Coroutine
 from .discovery import Discovery
 from .peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo import DEBUG_DISCOVERY
 
-
 class ListenProtocol(asyncio.DatagramProtocol):
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
     super().__init__()
@@ -29,30 +28,30 @@ class UDPDiscovery(Discovery):
     node_id: str,
     node_port: int,
     listen_port: int,
+    broadcast_port: int,
     create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
-    broadcast_port: int = None,
     broadcast_interval: int = 1,
-    device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
     discovery_timeout: int = 30,
+    device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
   ):
     self.node_id = node_id
     self.node_port = node_port
-    self.device_capabilities = device_capabilities
     self.listen_port = listen_port
-    self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
+    self.broadcast_port = broadcast_port
+    self.create_peer_handle = create_peer_handle
     self.broadcast_interval = broadcast_interval
+    self.discovery_timeout = discovery_timeout
+    self.device_capabilities = device_capabilities
     self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
     self.broadcast_task = None
     self.listen_task = None
     self.cleanup_task = None
-    self.discovery_timeout = discovery_timeout
-    self.create_peer_handle = create_peer_handle
 
   async def start(self):
     self.device_capabilities = device_capabilities()
     self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
-    self.listen_task = asyncio.create_task(self.task_listen_for_peers())
-    self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
+    # self.listen_task = asyncio.create_task(self.task_listen_for_peers())
+    # self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
 
   async def stop(self):
     if self.broadcast_task: self.broadcast_task.cancel()
@@ -62,7 +61,6 @@ class UDPDiscovery(Discovery):
       await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
 
   async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-    if DEBUG_DISCOVERY >= 2: print("Starting peer discovery process...")
     if wait_for_peers > 0:
       while len(self.known_peers) < wait_for_peers:
         if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
@@ -70,25 +68,26 @@ class UDPDiscovery(Discovery):
     return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
 
   async def task_broadcast_presence(self):
-    transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET)
-    sock = transport.get_extra_info("socket")
-    sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
-
-    message = json.dumps({
-      "type": "discovery",
-      "node_id": self.node_id,
-      "grpc_port": self.node_port,
-      "device_capabilities": self.device_capabilities.to_dict(),
-    }).encode("utf-8")
-
     while True:
       try:
+        message = json.dumps({
+          "type": "discovery",
+          "node_id": self.node_id,
+          "grpc_port": self.node_port,
+          "device_capabilities": self.device_capabilities.to_dict(),
+        }).encode("utf-8")
         if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
+
+        transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET)
+        sock = transport.get_extra_info("socket")
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
         transport.sendto(message, ("<broadcast>", self.broadcast_port))
       except Exception as e:
         print(f"Error in broadcast presence: {e}")
         print(traceback.format_exc())
       finally:
+        if transport:
+          transport.close()
         await asyncio.sleep(self.broadcast_interval)
 
   async def on_listen_message(self, data, addr):