|
@@ -7,7 +7,7 @@ from typing import List, Dict, Callable, Tuple, Coroutine
|
|
|
from .discovery import Discovery
|
|
|
from .peer_handle import PeerHandle
|
|
|
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
|
|
|
-from exo.helpers import DEBUG, DEBUG_DISCOVERY
|
|
|
+from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
|
|
|
|
|
|
class ListenProtocol(asyncio.DatagramProtocol):
|
|
|
def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
|
|
@@ -22,6 +22,17 @@ class ListenProtocol(asyncio.DatagramProtocol):
|
|
|
asyncio.create_task(self.on_message(data, addr))
|
|
|
|
|
|
|
|
|
+class BroadcastProtocol(asyncio.DatagramProtocol):
|
|
|
+ def __init__(self, message: str, broadcast_port: int):
|
|
|
+ self.message = message
|
|
|
+ self.broadcast_port = broadcast_port
|
|
|
+
|
|
|
+ def connection_made(self, transport):
|
|
|
+ sock = transport.get_extra_info("socket")
|
|
|
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
|
|
+ transport.sendto(self.message.encode("utf-8"), ("<broadcast>", self.broadcast_port))
|
|
|
+
|
|
|
+
|
|
|
class UDPDiscovery(Discovery):
|
|
|
def __init__(
|
|
|
self,
|
|
@@ -68,31 +79,40 @@ class UDPDiscovery(Discovery):
|
|
|
return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
|
|
|
|
|
|
async def task_broadcast_presence(self):
|
|
|
+ message = json.dumps({
|
|
|
+ "type": "discovery",
|
|
|
+ "node_id": self.node_id,
|
|
|
+ "grpc_port": self.node_port,
|
|
|
+ "device_capabilities": self.device_capabilities.to_dict(),
|
|
|
+ })
|
|
|
+
|
|
|
+ if DEBUG_DISCOVERY >= 2:
|
|
|
+ print("Starting task_broadcast_presence...")
|
|
|
+ print(f"\nBroadcast message: {message}")
|
|
|
+
|
|
|
while True:
|
|
|
- try:
|
|
|
- message = json.dumps({
|
|
|
- "type": "discovery",
|
|
|
- "node_id": self.node_id,
|
|
|
- "grpc_port": self.node_port,
|
|
|
- "device_capabilities": self.device_capabilities.to_dict(),
|
|
|
- }).encode("utf-8")
|
|
|
- if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
|
|
|
-
|
|
|
- transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET)
|
|
|
- sock = transport.get_extra_info("socket")
|
|
|
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
|
|
- transport.sendto(message, ("<broadcast>", self.broadcast_port))
|
|
|
- except Exception as e:
|
|
|
- print(f"Error in broadcast presence: {e}")
|
|
|
- print(traceback.format_exc())
|
|
|
- finally:
|
|
|
- if transport:
|
|
|
- try:
|
|
|
- transport.close()
|
|
|
- except:
|
|
|
- if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
|
|
|
- if DEBUG_DISCOVERY >= 2: traceback.print_exc()
|
|
|
- await asyncio.sleep(self.broadcast_interval)
|
|
|
+ # Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
|
|
|
+ # the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
|
|
|
+ for addr in get_all_ip_addresses():
|
|
|
+ transport = None
|
|
|
+ try:
|
|
|
+ transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
|
|
|
+ lambda: BroadcastProtocol(message, self.broadcast_port),
|
|
|
+ local_addr=(addr, 0),
|
|
|
+ family=socket.AF_INET
|
|
|
+ )
|
|
|
+ if DEBUG_DISCOVERY >= 3:
|
|
|
+ print(f"Broadcasting presence at ({addr})")
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error in broadcast presence ({addr}): {e}")
|
|
|
+ finally:
|
|
|
+ if transport:
|
|
|
+ try:
|
|
|
+ transport.close()
|
|
|
+ except Exception as e:
|
|
|
+ if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
|
|
|
+ if DEBUG_DISCOVERY >= 2: traceback.print_exc()
|
|
|
+ await asyncio.sleep(self.broadcast_interval)
|
|
|
|
|
|
async def on_listen_message(self, data, addr):
|
|
|
if not data:
|
|
@@ -120,7 +140,8 @@ class UDPDiscovery(Discovery):
|
|
|
peer_port = message["grpc_port"]
|
|
|
device_capabilities = DeviceCapabilities(**message["device_capabilities"])
|
|
|
if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
|
|
|
- if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
|
|
|
+ if DEBUG >= 1: print(
|
|
|
+ f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
|
|
|
self.known_peers[peer_id] = (
|
|
|
self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
|
|
|
time.time(),
|
|
@@ -129,7 +150,8 @@ class UDPDiscovery(Discovery):
|
|
|
self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
|
|
|
|
|
|
async def task_listen_for_peers(self):
|
|
|
- await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port))
|
|
|
+ await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
|
|
|
+ local_addr=("0.0.0.0", self.listen_port))
|
|
|
if DEBUG_DISCOVERY >= 2: print("Started listen task")
|
|
|
|
|
|
async def task_cleanup_peers(self):
|
|
@@ -140,7 +162,9 @@ class UDPDiscovery(Discovery):
|
|
|
peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
|
|
|
if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
|
|
|
]
|
|
|
- if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
|
|
|
+ if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {
|
|
|
+ peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for
|
|
|
+ peer_handle, connected_at, last_seen in self.known_peers.values()})
|
|
|
for peer_id in peers_to_remove:
|
|
|
if peer_id in self.known_peers: del self.known_peers[peer_id]
|
|
|
if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
|