|
@@ -3,7 +3,7 @@ import json
|
|
|
import socket
|
|
|
import time
|
|
|
import traceback
|
|
|
-from typing import List, Dict, Callable, Tuple, Coroutine
|
|
|
+from typing import List, Dict, Callable, Tuple, Coroutine, Optional
|
|
|
from exo.networking.discovery import Discovery
|
|
|
from exo.networking.peer_handle import PeerHandle
|
|
|
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
|
|
@@ -45,7 +45,8 @@ class UDPDiscovery(Discovery):
|
|
|
broadcast_interval: int = 2.5,
|
|
|
discovery_timeout: int = 30,
|
|
|
device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
|
|
|
- allowed_node_ids: List[str] = None,
|
|
|
+ allowed_node_ids: Optional[List[str]] = None,
|
|
|
+ allowed_interface_types: Optional[List[str]] = None,
|
|
|
):
|
|
|
self.node_id = node_id
|
|
|
self.node_port = node_port
|
|
@@ -56,6 +57,7 @@ class UDPDiscovery(Discovery):
|
|
|
self.discovery_timeout = discovery_timeout
|
|
|
self.device_capabilities = device_capabilities
|
|
|
self.allowed_node_ids = allowed_node_ids
|
|
|
+ self.allowed_interface_types = allowed_interface_types
|
|
|
self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
|
|
|
self.broadcast_task = None
|
|
|
self.listen_task = None
|
|
@@ -147,6 +149,12 @@ class UDPDiscovery(Discovery):
|
|
|
peer_prio = message["priority"]
|
|
|
peer_interface_name = message["interface_name"]
|
|
|
peer_interface_type = message["interface_type"]
|
|
|
+
|
|
|
+ # Skip if interface type is not in allowed list
|
|
|
+ if self.allowed_interface_types and peer_interface_type not in self.allowed_interface_types:
|
|
|
+ if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as its interface type {peer_interface_type} is not in the allowed interface types list")
|
|
|
+ return
|
|
|
+
|
|
|
device_capabilities = DeviceCapabilities(**message["device_capabilities"])
|
|
|
|
|
|
if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
|