Browse Source

allowed interface types

Alex Cheema 7 months ago
parent
commit
571b26c50e
3 changed files with 16 additions and 5 deletions
  1. 1 1
      .github/workflows/bench_job.yml
  2. 5 2
      exo/main.py
  3. 10 2
      exo/networking/udp/udp_discovery.py

+ 1 - 1
.github/workflows/bench_job.yml

@@ -74,7 +74,7 @@ jobs:
           export PATH="/usr/local/bin:/opt/homebrew/bin:$PATH"
 
           echo "Starting exo daemon..."
-          DEBUG=6 DEBUG_DISCOVERY=6 exo --node-id="${MY_NODE_ID}" --node-id-filter="${ALL_NODE_IDS}" --chatgpt-api-port 52415 > output1.log 2>&1 &
+          DEBUG=6 DEBUG_DISCOVERY=6 exo --node-id="${MY_NODE_ID}" --node-id-filter="${ALL_NODE_IDS}" --interface-type-filter="Ethernet" --chatgpt-api-port 52415 > output1.log 2>&1 &
           PID1=$!
           echo "Exo process started with PID: $PID1"
           tail -f output1.log &

+ 5 - 2
exo/main.py

@@ -59,6 +59,7 @@ parser.add_argument("--default-temp", type=float, help="Default token sampling t
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
+parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
 args = parser.parse_args()
 print(f"Selected inference engine: {args.inference_engine}")
 
@@ -90,8 +91,9 @@ if DEBUG >= 0:
   for chatgpt_api_endpoint in chatgpt_api_endpoints:
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
 
-# Convert node-id-filter to list if provided
+# Convert node-id-filter and interface-type-filter to lists if provided
 allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None
+allowed_interface_types = args.interface_type_filter.split(',') if args.interface_type_filter else None
 
 if args.discovery_module == "udp":
   discovery = UDPDiscovery(
@@ -101,7 +103,8 @@ if args.discovery_module == "udp":
     args.broadcast_port,
     lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
     discovery_timeout=args.discovery_timeout,
-    allowed_node_ids=allowed_node_ids
+    allowed_node_ids=allowed_node_ids,
+    allowed_interface_types=allowed_interface_types
   )
 elif args.discovery_module == "tailscale":
   discovery = TailscaleDiscovery(

+ 10 - 2
exo/networking/udp/udp_discovery.py

@@ -3,7 +3,7 @@ import json
 import socket
 import time
 import traceback
-from typing import List, Dict, Callable, Tuple, Coroutine
+from typing import List, Dict, Callable, Tuple, Coroutine, Optional
 from exo.networking.discovery import Discovery
 from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
@@ -45,7 +45,8 @@ class UDPDiscovery(Discovery):
     broadcast_interval: int = 2.5,
     discovery_timeout: int = 30,
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
-    allowed_node_ids: List[str] = None,
+    allowed_node_ids: Optional[List[str]] = None,
+    allowed_interface_types: Optional[List[str]] = None,
   ):
     self.node_id = node_id
     self.node_port = node_port
@@ -56,6 +57,7 @@ class UDPDiscovery(Discovery):
     self.discovery_timeout = discovery_timeout
     self.device_capabilities = device_capabilities
     self.allowed_node_ids = allowed_node_ids
+    self.allowed_interface_types = allowed_interface_types
     self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
     self.broadcast_task = None
     self.listen_task = None
@@ -147,6 +149,12 @@ class UDPDiscovery(Discovery):
       peer_prio = message["priority"]
       peer_interface_name = message["interface_name"]
       peer_interface_type = message["interface_type"]
+
+      # Skip if interface type is not in allowed list
+      if self.allowed_interface_types and peer_interface_type not in self.allowed_interface_types:
+        if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as its interface type {peer_interface_type} is not in the allowed interface types list")
+        return
+
       device_capabilities = DeviceCapabilities(**message["device_capabilities"])
 
       if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":