Browse Source

prioritise network interfaces and display in the tui

Alex Cheema 6 months ago
parent
commit
b8c4b46fe9

+ 31 - 3
exo/helpers.py

@@ -222,7 +222,7 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
     return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
     return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
 
 
 
 
-def get_all_ip_addresses():
+def get_all_ip_addresses_and_interfaces():
   try:
   try:
     ip_addresses = []
     ip_addresses = []
     for interface in netifaces.interfaces():
     for interface in netifaces.interfaces():
@@ -230,12 +230,40 @@ def get_all_ip_addresses():
       if netifaces.AF_INET in ifaddresses:
       if netifaces.AF_INET in ifaddresses:
         for link in ifaddresses[netifaces.AF_INET]:
         for link in ifaddresses[netifaces.AF_INET]:
           ip = link['addr']
           ip = link['addr']
-          ip_addresses.append(ip)
+          ip_addresses.append((ip, interface))
     return list(set(ip_addresses))
     return list(set(ip_addresses))
   except:
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
-    return ["localhost"]
+    return [("localhost", "lo")]
 
 
+def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
+  # Local container/virtual interfaces
+  if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
+    'bridge' in ifname):
+    return (7, "Container Virtual")
+
+  # Loopback interface
+  if ifname.startswith('lo'):
+    return (6, "Loopback")
+
+  # Thunderbolt/10GbE detection
+  if ifname.startswith(('tb', 'nx', 'ten')):
+    return (5, "Thunderbolt/10GbE")
+
+  # Regular ethernet detection
+  if ifname.startswith(('eth', 'en')) and not ifname.startswith(('en1', 'en0')):
+    return (4, "Ethernet")
+
+  # WiFi detection
+  if ifname.startswith(('wlan', 'wifi', 'wl')) or ifname in ['en0', 'en1']:
+    return (3, "WiFi")
+
+  # Non-local virtual interfaces (VPNs, tunnels)
+  if ifname.startswith(('tun', 'tap', 'vtun', 'utun', 'gif', 'stf', 'awdl', 'llw')):
+    return (1, "External Virtual")
+
+  # Other physical interfaces
+  return (2, "Other")
 
 
 async def shutdown(signal, loop, server):
 async def shutdown(signal, loop, server):
   """Gracefully shutdown the server and close the asyncio loop."""
   """Gracefully shutdown the server and close the asyncio loop."""

+ 6 - 6
exo/main.py

@@ -21,7 +21,7 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.inference.tokenizers import resolve_tokenizer
@@ -80,8 +80,8 @@ if args.node_port is None:
   if DEBUG >= 1: print(f"Using available port: {args.node_port}")
   if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
 
 args.node_id = args.node_id or get_or_create_node_id()
 args.node_id = args.node_id or get_or_create_node_id()
-chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
-web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
+chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip, _ in get_all_ip_addresses_and_interfaces()]
+web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip, _ in get_all_ip_addresses_and_interfaces()]
 if DEBUG >= 0:
 if DEBUG >= 0:
   print("Chat interface started:")
   print("Chat interface started:")
   for web_chat_url in web_chat_urls:
   for web_chat_url in web_chat_urls:
@@ -99,7 +99,7 @@ if args.discovery_module == "udp":
     args.node_port,
     args.node_port,
     args.listen_port,
     args.listen_port,
     args.broadcast_port,
     args.broadcast_port,
-    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
     discovery_timeout=args.discovery_timeout,
     discovery_timeout=args.discovery_timeout,
     allowed_node_ids=allowed_node_ids
     allowed_node_ids=allowed_node_ids
   )
   )
@@ -107,7 +107,7 @@ elif args.discovery_module == "tailscale":
   discovery = TailscaleDiscovery(
   discovery = TailscaleDiscovery(
     args.node_id,
     args.node_id,
     args.node_port,
     args.node_port,
-    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
     discovery_timeout=args.discovery_timeout,
     discovery_timeout=args.discovery_timeout,
     tailscale_api_key=args.tailscale_api_key,
     tailscale_api_key=args.tailscale_api_key,
     tailnet=args.tailnet_name,
     tailnet=args.tailnet_name,
@@ -116,7 +116,7 @@ elif args.discovery_module == "tailscale":
 elif args.discovery_module == "manual":
 elif args.discovery_module == "manual":
   if not args.discovery_config_path:
   if not args.discovery_config_path:
     raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
     raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
-  discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+  discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
 node = StandardNode(
   args.node_id,
   args.node_id,

+ 12 - 5
exo/networking/grpc/grpc_peer_handle.py

@@ -14,9 +14,10 @@ from exo.helpers import DEBUG
 
 
 
 
 class GRPCPeerHandle(PeerHandle):
 class GRPCPeerHandle(PeerHandle):
-  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
+  def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
     self._id = _id
     self._id = _id
     self.address = address
     self.address = address
+    self.desc = desc
     self._device_capabilities = device_capabilities
     self._device_capabilities = device_capabilities
     self.channel = None
     self.channel = None
     self.stub = None
     self.stub = None
@@ -27,6 +28,9 @@ class GRPCPeerHandle(PeerHandle):
   def addr(self) -> str:
   def addr(self) -> str:
     return self.address
     return self.address
 
 
+  def description(self) -> str:
+    return self.desc
+
   def device_capabilities(self) -> DeviceCapabilities:
   def device_capabilities(self) -> DeviceCapabilities:
     return self._device_capabilities
     return self._device_capabilities
 
 
@@ -119,12 +123,15 @@ class GRPCPeerHandle(PeerHandle):
     topology = Topology()
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
     for node_id, capabilities in response.nodes.items():
       device_capabilities = DeviceCapabilities(
       device_capabilities = DeviceCapabilities(
-        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+        model=capabilities.model,
+        chip=capabilities.chip,
+        memory=capabilities.memory,
+        flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
       )
       )
       topology.update_node(node_id, device_capabilities)
       topology.update_node(node_id, device_capabilities)
-    for node_id, peers in response.peer_graph.items():
-      for peer_id in peers.peer_ids:
-        topology.add_edge(node_id, peer_id)
+    for node_id, peer_connections in response.peer_graph.items():
+      for conn in peer_connections.connections:
+        topology.add_edge(node_id, conn.to_id, conn.description)
     return topology
     return topology
 
 
   async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
   async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:

+ 9 - 1
exo/networking/grpc/grpc_server.py

@@ -96,7 +96,15 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         )
         )
       for node_id, cap in topology.nodes.items()
       for node_id, cap in topology.nodes.items()
     }
     }
-    peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
+    peer_graph = {
+      node_id: node_service_pb2.PeerConnections(
+        connections=[
+          node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
+          for conn in connections
+        ]
+      )
+      for node_id, connections in topology.peer_graph.items()
+    }
     if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
     if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
     return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
     return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
 
 

+ 8 - 3
exo/networking/grpc/node_service.proto

@@ -53,11 +53,16 @@ message CollectTopologyRequest {
 
 
 message Topology {
 message Topology {
   map<string, DeviceCapabilities> nodes = 1;
   map<string, DeviceCapabilities> nodes = 1;
-  map<string, Peers> peer_graph = 2;
+  map<string, PeerConnections> peer_graph = 2;
 }
 }
 
 
-message Peers {
-    repeated string peer_ids = 1;
+message PeerConnection {
+  string to_id = 1;
+  optional string description = 2;
+}
+
+message PeerConnections {
+  repeated PeerConnection connections = 1;
 }
 }
 
 
 message DeviceFlops {
 message DeviceFlops {

File diff suppressed because it is too large
+ 11 - 1
exo/networking/grpc/node_service_pb2.py


+ 2 - 7
exo/networking/grpc/node_service_pb2_grpc.py

@@ -5,10 +5,8 @@ import warnings
 
 
 from . import node_service_pb2 as node__service__pb2
 from . import node_service_pb2 as node__service__pb2
 
 
-GRPC_GENERATED_VERSION = '1.64.1'
+GRPC_GENERATED_VERSION = '1.68.0'
 GRPC_VERSION = grpc.__version__
 GRPC_VERSION = grpc.__version__
-EXPECTED_ERROR_RELEASE = '1.65.0'
-SCHEDULED_RELEASE_DATE = 'June 25, 2024'
 _version_not_supported = False
 _version_not_supported = False
 
 
 try:
 try:
@@ -18,15 +16,12 @@ except ImportError:
     _version_not_supported = True
     _version_not_supported = True
 
 
 if _version_not_supported:
 if _version_not_supported:
-    warnings.warn(
+    raise RuntimeError(
         f'The grpc package installed is at version {GRPC_VERSION},'
         f'The grpc package installed is at version {GRPC_VERSION},'
         + f' but the generated code in node_service_pb2_grpc.py depends on'
         + f' but the generated code in node_service_pb2_grpc.py depends on'
         + f' grpcio>={GRPC_GENERATED_VERSION}.'
         + f' grpcio>={GRPC_GENERATED_VERSION}.'
         + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
         + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
         + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
         + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
-        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
-        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
-        RuntimeWarning
     )
     )
 
 
 
 

+ 1 - 1
exo/networking/manual/manual_discovery.py

@@ -53,7 +53,7 @@ class ManualDiscovery(Discovery):
           peer = self.known_peers.get(peer_id)
           peer = self.known_peers.get(peer_id)
           if not peer:
           if not peer:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
-            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
+            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities)
           is_healthy = await peer.health_check()
           is_healthy = await peer.health_check()
           if is_healthy:
           if is_healthy:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")

+ 5 - 5
exo/networking/manual/test_manual_discovery.py

@@ -14,7 +14,7 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
   async def asyncSetUp(self):
     self.peer1 = mock.AsyncMock()
     self.peer1 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
     _ = self.discovery1.start()
     _ = self.discovery1.start()
 
 
   async def asyncTearDown(self):
   async def asyncTearDown(self):
@@ -33,8 +33,8 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
     self.peer2 = mock.AsyncMock()
     self.peer2 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
     self.peer2.connect = mock.AsyncMock()
     self.peer2.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
     await self.discovery1.start()
     await self.discovery1.start()
     await self.discovery2.start()
     await self.discovery2.start()
 
 
@@ -63,8 +63,8 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
     self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
     await self.server1.start()
     await self.server1.start()
     await self.server2.start()
     await self.server2.start()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
     await self.discovery1.start()
     await self.discovery1.start()
     await self.discovery2.start()
     await self.discovery2.start()
 
 

+ 4 - 0
exo/networking/peer_handle.py

@@ -15,6 +15,10 @@ class PeerHandle(ABC):
   def addr(self) -> str:
   def addr(self) -> str:
     pass
     pass
 
 
+  @abstractmethod
+  def description(self) -> str:
+    pass
+
   @abstractmethod
   @abstractmethod
   def device_capabilities(self) -> DeviceCapabilities:
   def device_capabilities(self) -> DeviceCapabilities:
     pass
     pass

+ 2 - 2
exo/networking/tailscale/tailscale_discovery.py

@@ -14,7 +14,7 @@ class TailscaleDiscovery(Discovery):
     self,
     self,
     node_id: str,
     node_id: str,
     node_port: int,
     node_port: int,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
     discovery_interval: int = 5,
     discovery_interval: int = 5,
     discovery_timeout: int = 30,
     discovery_timeout: int = 30,
     update_interval: int = 15,
     update_interval: int = 15,
@@ -91,7 +91,7 @@ class TailscaleDiscovery(Discovery):
             continue
             continue
 
 
           if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
           if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
-            new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+            new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", "TS", device_capabilities)
             if not await new_peer_handle.health_check():
             if not await new_peer_handle.health_check():
               if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
               if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
               continue
               continue

+ 1 - 1
exo/networking/tailscale/test_tailscale_discovery.py

@@ -13,7 +13,7 @@ class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
     self.discovery = TailscaleDiscovery(
     self.discovery = TailscaleDiscovery(
       node_id="test_node",
       node_id="test_node",
       node_port=50051,
       node_port=50051,
-      create_peer_handle=lambda peer_id, address, device_capabilities: unittest.mock.Mock(spec=PeerHandle, id=lambda: peer_id),
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: unittest.mock.Mock(spec=PeerHandle, id=lambda: peer_id),
       tailscale_api_key=self.tailscale_api_key,
       tailscale_api_key=self.tailscale_api_key,
       tailnet=self.tailnet
       tailnet=self.tailnet
     )
     )

+ 4 - 4
exo/networking/udp/test_udp_discovery.py

@@ -13,8 +13,8 @@ class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
     self.peer2 = mock.AsyncMock()
     self.peer2 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
     self.peer2.connect = mock.AsyncMock()
     self.peer2.connect = mock.AsyncMock()
-    self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
-    self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
+    self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
     await self.discovery1.start()
     await self.discovery1.start()
     await self.discovery2.start()
     await self.discovery2.start()
 
 
@@ -41,8 +41,8 @@ class TestUDPDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.server2 = GRPCServer(self.node2, "localhost", 50054)
     self.server2 = GRPCServer(self.node2, "localhost", 50054)
     await self.server1.start()
     await self.server1.start()
     await self.server2.start()
     await self.server2.start()
-    self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
-    self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
+    self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
     await self.discovery1.start()
     await self.discovery1.start()
     await self.discovery2.start()
     await self.discovery2.start()
 
 

+ 12 - 11
exo/networking/udp/udp_discovery.py

@@ -7,7 +7,7 @@ from typing import List, Dict, Callable, Tuple, Coroutine
 from exo.networking.discovery import Discovery
 from exo.networking.discovery import Discovery
 from exo.networking.peer_handle import PeerHandle
 from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
-from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
+from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses_and_interfaces, get_interface_priority_and_type
 
 
 
 
 class ListenProtocol(asyncio.DatagramProtocol):
 class ListenProtocol(asyncio.DatagramProtocol):
@@ -41,7 +41,7 @@ class UDPDiscovery(Discovery):
     node_port: int,
     node_port: int,
     listen_port: int,
     listen_port: int,
     broadcast_port: int,
     broadcast_port: int,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
     broadcast_interval: int = 1,
     broadcast_interval: int = 1,
     discovery_timeout: int = 30,
     discovery_timeout: int = 30,
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
@@ -87,27 +87,27 @@ class UDPDiscovery(Discovery):
     while True:
     while True:
       # Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
       # Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
       # the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
       # the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
-      for addr in get_all_ip_addresses():
+      for addr, interface_name in get_all_ip_addresses_and_interfaces():
+        interface_priority, _ = get_interface_priority_and_type(interface_name)
         message = json.dumps({
         message = json.dumps({
           "type": "discovery",
           "type": "discovery",
           "node_id": self.node_id,
           "node_id": self.node_id,
           "grpc_port": self.node_port,
           "grpc_port": self.node_port,
           "device_capabilities": self.device_capabilities.to_dict(),
           "device_capabilities": self.device_capabilities.to_dict(),
-          "priority": 1,  # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "priority": interface_priority, # TODO: Prioritise interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "interface_name": interface_name,
         })
         })
-        if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")
+        if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority}): {message}")
 
 
         transport = None
         transport = None
         try:
         try:
           transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
           transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
-          if DEBUG_DISCOVERY >= 3:
-            print(f"Broadcasting presence at ({addr})")
+          if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority})")
         except Exception as e:
         except Exception as e:
-          print(f"Error in broadcast presence ({addr}): {e}")
+          print(f"Error in broadcast presence ({addr} - {interface_name} - {interface_priority}): {e}")
         finally:
         finally:
           if transport:
           if transport:
-            try:
-              transport.close()
+            try: transport.close()
             except Exception as e:
             except Exception as e:
               if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
               if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
               if DEBUG_DISCOVERY >= 2: traceback.print_exc()
               if DEBUG_DISCOVERY >= 2: traceback.print_exc()
@@ -144,6 +144,7 @@ class UDPDiscovery(Discovery):
       peer_host = addr[0]
       peer_host = addr[0]
       peer_port = message["grpc_port"]
       peer_port = message["grpc_port"]
       peer_prio = message["priority"]
       peer_prio = message["priority"]
+      peer_interface_name = message["interface_name"]
       device_capabilities = DeviceCapabilities(**message["device_capabilities"])
       device_capabilities = DeviceCapabilities(**message["device_capabilities"])
 
 
       if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
       if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
@@ -153,7 +154,7 @@ class UDPDiscovery(Discovery):
             if DEBUG >= 1:
             if DEBUG >= 1:
               print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
               print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
             return
             return
-        new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+        new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", peer_interface_name, device_capabilities)
         if not await new_peer_handle.health_check():
         if not await new_peer_handle.health_check():
           if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
           if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
           return
           return

+ 1 - 1
exo/orchestration/standard_node.py

@@ -398,7 +398,7 @@ class StandardNode(Node):
 
 
     for peer in self.peers:
     for peer in self.peers:
       next_topology.update_node(peer.id(), peer.device_capabilities())
       next_topology.update_node(peer.id(), peer.device_capabilities())
-      next_topology.add_edge(self.id, peer.id())
+      next_topology.add_edge(self.id, peer.id(), peer.description())
 
 
       if peer.id() in prev_visited:
       if peer.id() in prev_visited:
         continue
         continue

+ 40 - 15
exo/topology/topology.py

@@ -1,11 +1,28 @@
 from .device_capabilities import DeviceCapabilities
 from .device_capabilities import DeviceCapabilities
-from typing import Dict, Set, Optional
+from typing import Dict, Set, Optional, NamedTuple
+from dataclasses import dataclass
 
 
+@dataclass
+class PeerConnection:
+  from_id: str
+  to_id: str
+  description: Optional[str] = None
+
+  def __hash__(self):
+    # Use both from_id and to_id for uniqueness in sets
+    return hash((self.from_id, self.to_id))
+
+  def __eq__(self, other):
+    if not isinstance(other, PeerConnection):
+      return False
+    # Compare both from_id and to_id for equality
+    return self.from_id == other.from_id and self.to_id == other.to_id
 
 
 class Topology:
 class Topology:
   def __init__(self):
   def __init__(self):
-    self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
-    self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph
+    self.nodes: Dict[str, DeviceCapabilities] = {}
+    # Store PeerConnection objects in the adjacency lists
+    self.peer_graph: Dict[str, Set[PeerConnection]] = {}
     self.active_node_id: Optional[str] = None
     self.active_node_id: Optional[str] = None
 
 
   def update_node(self, node_id: str, device_capabilities: DeviceCapabilities):
   def update_node(self, node_id: str, device_capabilities: DeviceCapabilities):
@@ -17,33 +34,41 @@ class Topology:
   def all_nodes(self):
   def all_nodes(self):
     return self.nodes.items()
     return self.nodes.items()
 
 
-  def add_edge(self, node1_id: str, node2_id: str):
+  def add_edge(self, node1_id: str, node2_id: str, description: Optional[str] = None):
     if node1_id not in self.peer_graph:
     if node1_id not in self.peer_graph:
       self.peer_graph[node1_id] = set()
       self.peer_graph[node1_id] = set()
     if node2_id not in self.peer_graph:
     if node2_id not in self.peer_graph:
       self.peer_graph[node2_id] = set()
       self.peer_graph[node2_id] = set()
-    self.peer_graph[node1_id].add(node2_id)
-    self.peer_graph[node2_id].add(node1_id)
+
+    # Create bidirectional connections with the same description
+    conn1 = PeerConnection(node1_id, node2_id, description)
+    conn2 = PeerConnection(node2_id, node1_id, description)
+
+    self.peer_graph[node1_id].add(conn1)
+    self.peer_graph[node2_id].add(conn2)
 
 
   def get_neighbors(self, node_id: str) -> Set[str]:
   def get_neighbors(self, node_id: str) -> Set[str]:
-    return self.peer_graph.get(node_id, set())
+    # Convert PeerConnection objects back to just destination IDs
+    return {conn.to_id for conn in self.peer_graph.get(node_id, set())}
 
 
   def all_edges(self):
   def all_edges(self):
     edges = []
     edges = []
-    for node, neighbors in self.peer_graph.items():
-      for neighbor in neighbors:
-        if (neighbor, node) not in edges:  # Avoid duplicate edges
-          edges.append((node, neighbor))
+    for node_id, connections in self.peer_graph.items():
+      for conn in connections:
+        # Only include each edge once by checking if reverse already exists
+        if not any(e[0] == conn.to_id and e[1] == conn.from_id for e in edges):
+          edges.append((conn.from_id, conn.to_id, conn.description))
     return edges
     return edges
 
 
   def merge(self, other: "Topology"):
   def merge(self, other: "Topology"):
     for node_id, capabilities in other.nodes.items():
     for node_id, capabilities in other.nodes.items():
       self.update_node(node_id, capabilities)
       self.update_node(node_id, capabilities)
-    for node_id, neighbors in other.peer_graph.items():
-      for neighbor in neighbors:
-        self.add_edge(node_id, neighbor)
+    for node_id, connections in other.peer_graph.items():
+      for conn in connections:
+        self.add_edge(conn.from_id, conn.to_id, conn.description)
 
 
   def __str__(self):
   def __str__(self):
     nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items())
     nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items())
-    edges_str = ", ".join(f"{node}: {neighbors}" for node, neighbors in self.peer_graph.items())
+    edges_str = ", ".join(f"{node}: {[f'{c.to_id}({c.description})' for c in conns]}"
+                         for node, conns in self.peer_graph.items())
     return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})"
     return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})"

+ 15 - 1
exo/viz/topology_viz.py

@@ -242,12 +242,19 @@ class TopologyViz:
             if info_y + j != y or info_x + k != x:
             if info_y + j != y or info_x + k != x:
               visualization[info_y + j][info_x + k] = char
               visualization[info_y + j][info_x + k] = char
 
 
-      # Draw line to next node
+      # Draw line to next node and add connection description
       next_i = (i+1) % num_partitions
       next_i = (i+1) % num_partitions
       next_angle = 2*math.pi*next_i/num_partitions
       next_angle = 2*math.pi*next_i/num_partitions
       next_x = int(center_x + radius_x*math.cos(next_angle))
       next_x = int(center_x + radius_x*math.cos(next_angle))
       next_y = int(center_y + radius_y*math.sin(next_angle))
       next_y = int(center_y + radius_y*math.sin(next_angle))
 
 
+      # Get connection descriptions
+      conn1 = self.topology.peer_graph.get(partition.node_id, set())
+      conn2 = self.topology.peer_graph.get(self.partitions[next_i].node_id, set())
+      description1 = next((c.description for c in conn1 if c.to_id == self.partitions[next_i].node_id), "")
+      description2 = next((c.description for c in conn2 if c.to_id == partition.node_id), "")
+      connection_description = f"{description1}/{description2}" if description1 != description2 else description1
+
       # Simple line drawing
       # Simple line drawing
       steps = max(abs(next_x - x), abs(next_y - y))
       steps = max(abs(next_x - x), abs(next_y - y))
       for step in range(1, steps):
       for step in range(1, steps):
@@ -256,6 +263,13 @@ class TopologyViz:
         if 0 <= line_y < 48 and 0 <= line_x < 100:
         if 0 <= line_y < 48 and 0 <= line_x < 100:
           visualization[line_y][line_x] = "-"
           visualization[line_y][line_x] = "-"
 
 
+      # Add connection description near the midpoint of the line
+      mid_x = (x + next_x) // 2
+      mid_y = (y + next_y) // 2
+      for j, char in enumerate(connection_description):
+        if 0 <= mid_y < 48 and 0 <= mid_x + j < 100:
+          visualization[mid_y][mid_x + j] = char
+
     # Convert to string
     # Convert to string
     return "\n".join("".join(str(char) for char in row) for row in visualization)
     return "\n".join("".join(str(char) for char in row) for row in visualization)
 
 

Some files were not shown because too many files changed in this diff