|
@@ -3,13 +3,12 @@ import json
|
|
|
import socket
|
|
|
import time
|
|
|
import traceback
|
|
|
-from typing import List, Dict, Callable, Tuple, Coroutine, Type
|
|
|
+from typing import List, Dict, Callable, Tuple, Coroutine
|
|
|
from .discovery import Discovery
|
|
|
from .peer_handle import PeerHandle
|
|
|
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
|
|
|
from exo import DEBUG_DISCOVERY
|
|
|
|
|
|
-
|
|
|
class ListenProtocol(asyncio.DatagramProtocol):
|
|
|
def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
|
|
|
super().__init__()
|
|
@@ -29,30 +28,30 @@ class UDPDiscovery(Discovery):
|
|
|
node_id: str,
|
|
|
node_port: int,
|
|
|
listen_port: int,
|
|
|
+ broadcast_port: int,
|
|
|
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
|
|
|
- broadcast_port: int = None,
|
|
|
broadcast_interval: int = 1,
|
|
|
- device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
|
|
|
discovery_timeout: int = 30,
|
|
|
+ device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
|
|
|
):
|
|
|
self.node_id = node_id
|
|
|
self.node_port = node_port
|
|
|
- self.device_capabilities = device_capabilities
|
|
|
self.listen_port = listen_port
|
|
|
- self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
|
|
|
+ self.broadcast_port = broadcast_port
|
|
|
+ self.create_peer_handle = create_peer_handle
|
|
|
self.broadcast_interval = broadcast_interval
|
|
|
+ self.discovery_timeout = discovery_timeout
|
|
|
+ self.device_capabilities = device_capabilities
|
|
|
self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
|
|
|
self.broadcast_task = None
|
|
|
self.listen_task = None
|
|
|
self.cleanup_task = None
|
|
|
- self.discovery_timeout = discovery_timeout
|
|
|
- self.create_peer_handle = create_peer_handle
|
|
|
|
|
|
async def start(self):
|
|
|
self.device_capabilities = device_capabilities()
|
|
|
self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
|
|
|
- self.listen_task = asyncio.create_task(self.task_listen_for_peers())
|
|
|
- self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
|
|
|
+ # self.listen_task = asyncio.create_task(self.task_listen_for_peers())
|
|
|
+ # self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
|
|
|
|
|
|
async def stop(self):
|
|
|
if self.broadcast_task: self.broadcast_task.cancel()
|
|
@@ -62,7 +61,6 @@ class UDPDiscovery(Discovery):
|
|
|
await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
|
|
|
|
|
|
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
|
|
|
- if DEBUG_DISCOVERY >= 2: print("Starting peer discovery process...")
|
|
|
if wait_for_peers > 0:
|
|
|
while len(self.known_peers) < wait_for_peers:
|
|
|
if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
|
|
@@ -70,25 +68,26 @@ class UDPDiscovery(Discovery):
|
|
|
return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
|
|
|
|
|
|
async def task_broadcast_presence(self):
|
|
|
- transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET)
|
|
|
- sock = transport.get_extra_info("socket")
|
|
|
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
|
|
-
|
|
|
- message = json.dumps({
|
|
|
- "type": "discovery",
|
|
|
- "node_id": self.node_id,
|
|
|
- "grpc_port": self.node_port,
|
|
|
- "device_capabilities": self.device_capabilities.to_dict(),
|
|
|
- }).encode("utf-8")
|
|
|
-
|
|
|
while True:
|
|
|
try:
|
|
|
+ message = json.dumps({
|
|
|
+ "type": "discovery",
|
|
|
+ "node_id": self.node_id,
|
|
|
+ "grpc_port": self.node_port,
|
|
|
+ "device_capabilities": self.device_capabilities.to_dict(),
|
|
|
+ }).encode("utf-8")
|
|
|
if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
|
|
|
+
|
|
|
+ transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET)
|
|
|
+ sock = transport.get_extra_info("socket")
|
|
|
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
|
|
transport.sendto(message, ("<broadcast>", self.broadcast_port))
|
|
|
except Exception as e:
|
|
|
print(f"Error in broadcast presence: {e}")
|
|
|
print(traceback.format_exc())
|
|
|
finally:
|
|
|
+ if transport:
|
|
|
+ transport.close()
|
|
|
await asyncio.sleep(self.broadcast_interval)
|
|
|
|
|
|
async def on_listen_message(self, data, addr):
|