Forráskód Böngészése

Add --node-id-filter command-line argument to restrict discovered peers to an allowed list of node IDs

Alex Cheema 5 hónapja
szülő
commit
d34e67a2c1

+ 8 - 2
exo/main.py

@@ -58,6 +58,7 @@ parser.add_argument("--prompt", type=str, help="Prompt for the model when using
 parser.add_argument("--default-temp", type=float, help="Default token sampling temperature", default=0.0)
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
+parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
 args = parser.parse_args()
 print(f"Selected inference engine: {args.inference_engine}")
 
@@ -89,6 +90,9 @@ if DEBUG >= 0:
   for chatgpt_api_endpoint in chatgpt_api_endpoints:
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
 
+# Convert node-id-filter to list if provided
+allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None
+
 if args.discovery_module == "udp":
   discovery = UDPDiscovery(
     args.node_id,
@@ -96,7 +100,8 @@ if args.discovery_module == "udp":
     args.listen_port,
     args.broadcast_port,
     lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
-    discovery_timeout=args.discovery_timeout
+    discovery_timeout=args.discovery_timeout,
+    allowed_node_ids=allowed_node_ids
   )
 elif args.discovery_module == "tailscale":
   discovery = TailscaleDiscovery(
@@ -105,7 +110,8 @@ elif args.discovery_module == "tailscale":
     lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
     discovery_timeout=args.discovery_timeout,
     tailscale_api_key=args.tailscale_api_key,
-    tailnet=args.tailnet_name
+    tailnet=args.tailnet_name,
+    allowed_node_ids=allowed_node_ids
   )
 elif args.discovery_module == "manual":
   if not args.discovery_config_path:

+ 6 - 0
exo/networking/tailscale/tailscale_discovery.py

@@ -21,6 +21,7 @@ class TailscaleDiscovery(Discovery):
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
     tailscale_api_key: str = None,
     tailnet: str = None,
+    allowed_node_ids: List[str] = None,
   ):
     self.node_id = node_id
     self.node_port = node_port
@@ -34,6 +35,7 @@ class TailscaleDiscovery(Discovery):
     self.cleanup_task = None
     self.tailscale_api_key = tailscale_api_key
     self.tailnet = tailnet
+    self.allowed_node_ids = allowed_node_ids
     self._device_id = None
     self.update_task = None
 
@@ -84,6 +86,10 @@ class TailscaleDiscovery(Discovery):
             if DEBUG_DISCOVERY >= 4: print(f"{device.device_id} does not have exo node attributes. skipping.")
             continue
 
+          if self.allowed_node_ids and peer_id not in self.allowed_node_ids:
+            if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as it's not in the allowed node IDs list")
+            continue
+
           if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
             new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
             if not await new_peer_handle.health_check():

+ 8 - 0
exo/networking/udp/udp_discovery.py

@@ -45,6 +45,7 @@ class UDPDiscovery(Discovery):
     broadcast_interval: int = 1,
     discovery_timeout: int = 30,
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
+    allowed_node_ids: List[str] = None,
   ):
     self.node_id = node_id
     self.node_port = node_port
@@ -54,6 +55,7 @@ class UDPDiscovery(Discovery):
     self.broadcast_interval = broadcast_interval
     self.discovery_timeout = discovery_timeout
     self.device_capabilities = device_capabilities
+    self.allowed_node_ids = allowed_node_ids
     self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
     self.broadcast_task = None
     self.listen_task = None
@@ -133,6 +135,12 @@ class UDPDiscovery(Discovery):
 
     if message["type"] == "discovery" and message["node_id"] != self.node_id:
       peer_id = message["node_id"]
+      
+      # Skip if peer_id is not in allowed list
+      if self.allowed_node_ids and peer_id not in self.allowed_node_ids:
+        if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as it's not in the allowed node IDs list")
+        return
+
       peer_host = addr[0]
       peer_port = message["grpc_port"]
       peer_prio = message["priority"]