
prioritise network interfaces and display in the TUI

Alex Cheema · 6 months ago
commit b8c4b46fe9

+ 31 - 3
exo/helpers.py

@@ -222,7 +222,7 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
     return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
 
 
-def get_all_ip_addresses():
+def get_all_ip_addresses_and_interfaces():
   try:
     ip_addresses = []
     for interface in netifaces.interfaces():
@@ -230,12 +230,40 @@ def get_all_ip_addresses():
       if netifaces.AF_INET in ifaddresses:
         for link in ifaddresses[netifaces.AF_INET]:
           ip = link['addr']
-          ip_addresses.append(ip)
+          ip_addresses.append((ip, interface))
     return list(set(ip_addresses))
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
-    return ["localhost"]
+    return [("localhost", "lo")]
 
+def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
+  # Local container/virtual interfaces
+  if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
+    'bridge' in ifname):
+    return (7, "Container Virtual")
+
+  # Loopback interface
+  if ifname.startswith('lo'):
+    return (6, "Loopback")
+
+  # Thunderbolt/10GbE detection
+  if ifname.startswith(('tb', 'nx', 'ten')):
+    return (5, "Thunderbolt/10GbE")
+
+  # Regular ethernet detection
+  if ifname.startswith(('eth', 'en')) and not ifname.startswith(('en1', 'en0')):
+    return (4, "Ethernet")
+
+  # WiFi detection
+  if ifname.startswith(('wlan', 'wifi', 'wl')) or ifname in ['en0', 'en1']:
+    return (3, "WiFi")
+
+  # Non-local virtual interfaces (VPNs, tunnels)
+  if ifname.startswith(('tun', 'tap', 'vtun', 'utun', 'gif', 'stf', 'awdl', 'llw')):
+    return (1, "External Virtual")
+
+  # Other physical interfaces
+  return (2, "Other")
 
 async def shutdown(signal, loop, server):
   """Gracefully shutdown the server and close the asyncio loop."""

+ 6 - 6
exo/main.py

@@ -21,7 +21,7 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
@@ -80,8 +80,8 @@ if args.node_port is None:
   if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
 args.node_id = args.node_id or get_or_create_node_id()
-chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
-web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
+chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip, _ in get_all_ip_addresses_and_interfaces()]
+web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip, _ in get_all_ip_addresses_and_interfaces()]
 if DEBUG >= 0:
   print("Chat interface started:")
   for web_chat_url in web_chat_urls:
@@ -99,7 +99,7 @@ if args.discovery_module == "udp":
     args.node_port,
     args.listen_port,
     args.broadcast_port,
-    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
     discovery_timeout=args.discovery_timeout,
     allowed_node_ids=allowed_node_ids
   )
@@ -107,7 +107,7 @@ elif args.discovery_module == "tailscale":
   discovery = TailscaleDiscovery(
     args.node_id,
     args.node_port,
-    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
     discovery_timeout=args.discovery_timeout,
     tailscale_api_key=args.tailscale_api_key,
     tailnet=args.tailnet_name,
@@ -116,7 +116,7 @@ elif args.discovery_module == "tailscale":
 elif args.discovery_module == "manual":
   if not args.discovery_config_path:
     raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
-  discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+  discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
   args.node_id,

+ 12 - 5
exo/networking/grpc/grpc_peer_handle.py

@@ -14,9 +14,10 @@ from exo.helpers import DEBUG
 
 
 class GRPCPeerHandle(PeerHandle):
-  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
+  def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
     self._id = _id
     self.address = address
+    self.desc = desc
     self._device_capabilities = device_capabilities
     self.channel = None
     self.stub = None
@@ -27,6 +28,9 @@ class GRPCPeerHandle(PeerHandle):
   def addr(self) -> str:
     return self.address
 
+  def description(self) -> str:
+    return self.desc
+
   def device_capabilities(self) -> DeviceCapabilities:
     return self._device_capabilities
 
@@ -119,12 +123,15 @@ class GRPCPeerHandle(PeerHandle):
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
       device_capabilities = DeviceCapabilities(
-        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+        model=capabilities.model,
+        chip=capabilities.chip,
+        memory=capabilities.memory,
+        flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
       )
       topology.update_node(node_id, device_capabilities)
-    for node_id, peers in response.peer_graph.items():
-      for peer_id in peers.peer_ids:
-        topology.add_edge(node_id, peer_id)
+    for node_id, peer_connections in response.peer_graph.items():
+      for conn in peer_connections.connections:
+        topology.add_edge(node_id, conn.to_id, conn.description)
     return topology
 
   async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
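
A minimal usage sketch of the updated constructor; the node id, address, and "tb0" description are illustrative values, not from the commit:

  from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
  from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES

  peer = GRPCPeerHandle("node2", "192.168.1.7:50051", "tb0", UNKNOWN_DEVICE_CAPABILITIES)
  print(peer.addr())         # "192.168.1.7:50051"
  print(peer.description())  # "tb0", surfaced later as the edge label in the TUI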

+ 9 - 1
exo/networking/grpc/grpc_server.py

@@ -96,7 +96,15 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         )
       for node_id, cap in topology.nodes.items()
     }
-    peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
+    peer_graph = {
+      node_id: node_service_pb2.PeerConnections(
+        connections=[
+          node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
+          for conn in connections
+        ]
+      )
+      for node_id, connections in topology.peer_graph.items()
+    }
     if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
     return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
 

+ 8 - 3
exo/networking/grpc/node_service.proto

@@ -53,11 +53,16 @@ message CollectTopologyRequest {
 
 message Topology {
   map<string, DeviceCapabilities> nodes = 1;
-  map<string, Peers> peer_graph = 2;
+  map<string, PeerConnections> peer_graph = 2;
 }
 
-message Peers {
-    repeated string peer_ids = 1;
+message PeerConnection {
+  string to_id = 1;
+  optional string description = 2;
+}
+
+message PeerConnections {
+  repeated PeerConnection connections = 1;
 }
 
 message DeviceFlops {
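
The generated Python bindings can then be populated the same way GRPCServer does above; a small sketch with illustrative node ids and description:

  from exo.networking.grpc import node_service_pb2

  peer_graph = {
    "node1": node_service_pb2.PeerConnections(connections=[
      node_service_pb2.PeerConnection(to_id="node2", description="tb0"),
    ])
  }
  topology_msg = node_service_pb2.Topology(peer_graph=peer_graph)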

+ 11 - 1
exo/networking/grpc/node_service_pb2.py

File diff suppressed because it is too large


+ 2 - 7
exo/networking/grpc/node_service_pb2_grpc.py

@@ -5,10 +5,8 @@ import warnings
 
 from . import node_service_pb2 as node__service__pb2
 
-GRPC_GENERATED_VERSION = '1.64.1'
+GRPC_GENERATED_VERSION = '1.68.0'
 GRPC_VERSION = grpc.__version__
-EXPECTED_ERROR_RELEASE = '1.65.0'
-SCHEDULED_RELEASE_DATE = 'June 25, 2024'
 _version_not_supported = False
 
 try:
@@ -18,15 +16,12 @@ except ImportError:
     _version_not_supported = True
 
 if _version_not_supported:
-    warnings.warn(
+    raise RuntimeError(
         f'The grpc package installed is at version {GRPC_VERSION},'
         + f' but the generated code in node_service_pb2_grpc.py depends on'
         + f' grpcio>={GRPC_GENERATED_VERSION}.'
         + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
         + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
-        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
-        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
-        RuntimeWarning
     )
 
 

+ 1 - 1
exo/networking/manual/manual_discovery.py

@@ -53,7 +53,7 @@ class ManualDiscovery(Discovery):
           peer = self.known_peers.get(peer_id)
           if not peer:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
-            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
+            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities)
           is_healthy = await peer.health_check()
           if is_healthy:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")

+ 5 - 5
exo/networking/manual/test_manual_discovery.py

@@ -14,7 +14,7 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
     self.peer1 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
     _ = self.discovery1.start()
 
   async def asyncTearDown(self):
@@ -33,8 +33,8 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
     self.peer2 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
     self.peer2.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
     await self.discovery1.start()
     await self.discovery2.start()
 
@@ -63,8 +63,8 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
     await self.server1.start()
     await self.server2.start()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
     await self.discovery1.start()
     await self.discovery2.start()
 

+ 4 - 0
exo/networking/peer_handle.py

@@ -15,6 +15,10 @@ class PeerHandle(ABC):
   def addr(self) -> str:
     pass
 
+  @abstractmethod
+  def description(self) -> str:
+    pass
+
   @abstractmethod
   def device_capabilities(self) -> DeviceCapabilities:
     pass

+ 2 - 2
exo/networking/tailscale/tailscale_discovery.py

@@ -14,7 +14,7 @@ class TailscaleDiscovery(Discovery):
     self,
     node_id: str,
     node_port: int,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
     discovery_interval: int = 5,
     discovery_timeout: int = 30,
     update_interval: int = 15,
@@ -91,7 +91,7 @@ class TailscaleDiscovery(Discovery):
             continue
 
           if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
-            new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+            new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", "TS", device_capabilities)
             if not await new_peer_handle.health_check():
               if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
               continue

+ 1 - 1
exo/networking/tailscale/test_tailscale_discovery.py

@@ -13,7 +13,7 @@ class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
     self.discovery = TailscaleDiscovery(
       node_id="test_node",
       node_port=50051,
-      create_peer_handle=lambda peer_id, address, device_capabilities: unittest.mock.Mock(spec=PeerHandle, id=lambda: peer_id),
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: unittest.mock.Mock(spec=PeerHandle, id=lambda: peer_id),
       tailscale_api_key=self.tailscale_api_key,
       tailnet=self.tailnet
     )

+ 4 - 4
exo/networking/udp/test_udp_discovery.py

@@ -13,8 +13,8 @@ class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
     self.peer2 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
     self.peer2.connect = mock.AsyncMock()
-    self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
-    self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
+    self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
     await self.discovery1.start()
     await self.discovery2.start()
 
@@ -41,8 +41,8 @@ class TestUDPDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.server2 = GRPCServer(self.node2, "localhost", 50054)
     await self.server1.start()
     await self.server2.start()
-    self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
-    self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
+    self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
     await self.discovery1.start()
     await self.discovery2.start()
 

+ 12 - 11
exo/networking/udp/udp_discovery.py

@@ -7,7 +7,7 @@ from typing import List, Dict, Callable, Tuple, Coroutine
 from exo.networking.discovery import Discovery
 from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
-from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
+from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses_and_interfaces, get_interface_priority_and_type
 
 
 class ListenProtocol(asyncio.DatagramProtocol):
@@ -41,7 +41,7 @@ class UDPDiscovery(Discovery):
     node_port: int,
     listen_port: int,
     broadcast_port: int,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
     broadcast_interval: int = 1,
     discovery_timeout: int = 30,
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
@@ -87,27 +87,27 @@ class UDPDiscovery(Discovery):
     while True:
       # Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
       # the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
-      for addr in get_all_ip_addresses():
+      for addr, interface_name in get_all_ip_addresses_and_interfaces():
+        interface_priority, _ = get_interface_priority_and_type(interface_name)
         message = json.dumps({
           "type": "discovery",
           "node_id": self.node_id,
           "grpc_port": self.node_port,
           "device_capabilities": self.device_capabilities.to_dict(),
-          "priority": 1,  # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "priority": interface_priority, # TODO: Prioritise interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "interface_name": interface_name,
         })
-        if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")
+        if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority}): {message}")
 
         transport = None
         try:
           transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
-          if DEBUG_DISCOVERY >= 3:
-            print(f"Broadcasting presence at ({addr})")
+          if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority})")
         except Exception as e:
-          print(f"Error in broadcast presence ({addr}): {e}")
+          print(f"Error in broadcast presence ({addr} - {interface_name} - {interface_priority}): {e}")
         finally:
           if transport:
-            try:
-              transport.close()
+            try: transport.close()
             except Exception as e:
               if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
               if DEBUG_DISCOVERY >= 2: traceback.print_exc()
@@ -144,6 +144,7 @@ class UDPDiscovery(Discovery):
       peer_host = addr[0]
       peer_port = message["grpc_port"]
       peer_prio = message["priority"]
+      peer_interface_name = message["interface_name"]
       device_capabilities = DeviceCapabilities(**message["device_capabilities"])
 
       if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
@@ -153,7 +154,7 @@ class UDPDiscovery(Discovery):
             if DEBUG >= 1:
               print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
             return
-        new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+        new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", peer_interface_name, device_capabilities)
         if not await new_peer_handle.health_check():
           if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
           return
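
For reference, a sketch of what a single broadcast payload looks like under this change; the node id, port, and interface are made up. On receipt, a peer that is already known with a higher or equal priority is kept and the new handle is ignored.

  import json

  message = json.dumps({
    "type": "discovery",
    "node_id": "node1",
    "grpc_port": 50051,
    "device_capabilities": {},  # DeviceCapabilities.to_dict() in practice
    "priority": 5,              # get_interface_priority_and_type("tb0") == (5, "Thunderbolt/10GbE")
    "interface_name": "tb0",
  })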

+ 1 - 1
exo/orchestration/standard_node.py

@@ -398,7 +398,7 @@ class StandardNode(Node):
 
     for peer in self.peers:
       next_topology.update_node(peer.id(), peer.device_capabilities())
-      next_topology.add_edge(self.id, peer.id())
+      next_topology.add_edge(self.id, peer.id(), peer.description())
 
       if peer.id() in prev_visited:
         continue

+ 40 - 15
exo/topology/topology.py

@@ -1,11 +1,28 @@
 from .device_capabilities import DeviceCapabilities
-from typing import Dict, Set, Optional
+from typing import Dict, Set, Optional, NamedTuple
+from dataclasses import dataclass
 
+@dataclass
+class PeerConnection:
+  from_id: str
+  to_id: str
+  description: Optional[str] = None
+
+  def __hash__(self):
+    # Use both from_id and to_id for uniqueness in sets
+    return hash((self.from_id, self.to_id))
+
+  def __eq__(self, other):
+    if not isinstance(other, PeerConnection):
+      return False
+    # Compare both from_id and to_id for equality
+    return self.from_id == other.from_id and self.to_id == other.to_id
 
 class Topology:
   def __init__(self):
-    self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
-    self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph
+    self.nodes: Dict[str, DeviceCapabilities] = {}
+    # Store PeerConnection objects in the adjacency lists
+    self.peer_graph: Dict[str, Set[PeerConnection]] = {}
     self.active_node_id: Optional[str] = None
 
   def update_node(self, node_id: str, device_capabilities: DeviceCapabilities):
@@ -17,33 +34,41 @@ class Topology:
   def all_nodes(self):
     return self.nodes.items()
 
-  def add_edge(self, node1_id: str, node2_id: str):
+  def add_edge(self, node1_id: str, node2_id: str, description: Optional[str] = None):
     if node1_id not in self.peer_graph:
       self.peer_graph[node1_id] = set()
     if node2_id not in self.peer_graph:
       self.peer_graph[node2_id] = set()
-    self.peer_graph[node1_id].add(node2_id)
-    self.peer_graph[node2_id].add(node1_id)
+
+    # Create bidirectional connections with the same description
+    conn1 = PeerConnection(node1_id, node2_id, description)
+    conn2 = PeerConnection(node2_id, node1_id, description)
+
+    self.peer_graph[node1_id].add(conn1)
+    self.peer_graph[node2_id].add(conn2)
 
   def get_neighbors(self, node_id: str) -> Set[str]:
-    return self.peer_graph.get(node_id, set())
+    # Convert PeerConnection objects back to just destination IDs
+    return {conn.to_id for conn in self.peer_graph.get(node_id, set())}
 
   def all_edges(self):
     edges = []
-    for node, neighbors in self.peer_graph.items():
-      for neighbor in neighbors:
-        if (neighbor, node) not in edges:  # Avoid duplicate edges
-          edges.append((node, neighbor))
+    for node_id, connections in self.peer_graph.items():
+      for conn in connections:
+        # Only include each edge once by checking if reverse already exists
+        if not any(e[0] == conn.to_id and e[1] == conn.from_id for e in edges):
+          edges.append((conn.from_id, conn.to_id, conn.description))
     return edges
 
   def merge(self, other: "Topology"):
     for node_id, capabilities in other.nodes.items():
       self.update_node(node_id, capabilities)
-    for node_id, neighbors in other.peer_graph.items():
-      for neighbor in neighbors:
-        self.add_edge(node_id, neighbor)
+    for node_id, connections in other.peer_graph.items():
+      for conn in connections:
+        self.add_edge(conn.from_id, conn.to_id, conn.description)
 
   def __str__(self):
     nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items())
-    edges_str = ", ".join(f"{node}: {neighbors}" for node, neighbors in self.peer_graph.items())
+    edges_str = ", ".join(f"{node}: {[f'{c.to_id}({c.description})' for c in conns]}"
+                         for node, conns in self.peer_graph.items())
     return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})"

+ 15 - 1
exo/viz/topology_viz.py

@@ -242,12 +242,19 @@ class TopologyViz:
             if info_y + j != y or info_x + k != x:
               visualization[info_y + j][info_x + k] = char
 
-      # Draw line to next node
+      # Draw line to next node and add connection description
       next_i = (i+1) % num_partitions
       next_angle = 2*math.pi*next_i/num_partitions
       next_x = int(center_x + radius_x*math.cos(next_angle))
       next_y = int(center_y + radius_y*math.sin(next_angle))
 
+      # Get connection descriptions
+      conn1 = self.topology.peer_graph.get(partition.node_id, set())
+      conn2 = self.topology.peer_graph.get(self.partitions[next_i].node_id, set())
+      description1 = next((c.description for c in conn1 if c.to_id == self.partitions[next_i].node_id), "")
+      description2 = next((c.description for c in conn2 if c.to_id == partition.node_id), "")
+      connection_description = f"{description1}/{description2}" if description1 != description2 else description1
+
       # Simple line drawing
       steps = max(abs(next_x - x), abs(next_y - y))
       for step in range(1, steps):
@@ -256,6 +263,13 @@ class TopologyViz:
         if 0 <= line_y < 48 and 0 <= line_x < 100:
           visualization[line_y][line_x] = "-"
 
+      # Add connection description near the midpoint of the line
+      mid_x = (x + next_x) // 2
+      mid_y = (y + next_y) // 2
+      for j, char in enumerate(connection_description):
+        if 0 <= mid_y < 48 and 0 <= mid_x + j < 100:
+          visualization[mid_y][mid_x + j] = char
+
     # Convert to string
     return "\n".join("".join(str(char) for char in row) for row in visualization)
 

Some files were not shown because too many files changed in this diff