Browse Source

prioritise network interfaces and display in the tui

Alex Cheema 8 tháng trước
mục cha
commit
b8c4b46fe9

+ 31 - 3
exo/helpers.py

@@ -222,7 +222,7 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
     return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
 
 
-def get_all_ip_addresses():
+def get_all_ip_addresses_and_interfaces():
   try:
     ip_addresses = []
     for interface in netifaces.interfaces():
@@ -230,12 +230,40 @@ def get_all_ip_addresses():
       if netifaces.AF_INET in ifaddresses:
         for link in ifaddresses[netifaces.AF_INET]:
           ip = link['addr']
-          ip_addresses.append(ip)
+          ip_addresses.append((ip, interface))
     return list(set(ip_addresses))
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
-    return ["localhost"]
+    return [("localhost", "lo")]
 
+def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
+  # Local container/virtual interfaces
+  if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
+    'bridge' in ifname):
+    return (7, "Container Virtual")
+
+  # Loopback interface
+  if ifname.startswith('lo'):
+    return (6, "Loopback")
+
+  # Thunderbolt/10GbE detection
+  if ifname.startswith(('tb', 'nx', 'ten')):
+    return (5, "Thunderbolt/10GbE")
+
+  # Regular ethernet detection
+  if ifname.startswith(('eth', 'en')) and not ifname.startswith(('en1', 'en0')):
+    return (4, "Ethernet")
+
+  # WiFi detection
+  if ifname.startswith(('wlan', 'wifi', 'wl')) or ifname in ['en0', 'en1']:
+    return (3, "WiFi")
+
+  # Non-local virtual interfaces (VPNs, tunnels)
+  if ifname.startswith(('tun', 'tap', 'vtun', 'utun', 'gif', 'stf', 'awdl', 'llw')):
+    return (1, "External Virtual")
+
+  # Other physical interfaces
+  return (2, "Other")
 
 async def shutdown(signal, loop, server):
   """Gracefully shutdown the server and close the asyncio loop."""

+ 6 - 6
exo/main.py

@@ -21,7 +21,7 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
@@ -80,8 +80,8 @@ if args.node_port is None:
   if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
 args.node_id = args.node_id or get_or_create_node_id()
-chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
-web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
+chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip, _ in get_all_ip_addresses_and_interfaces()]
+web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip, _ in get_all_ip_addresses_and_interfaces()]
 if DEBUG >= 0:
   print("Chat interface started:")
   for web_chat_url in web_chat_urls:
@@ -99,7 +99,7 @@ if args.discovery_module == "udp":
     args.node_port,
     args.listen_port,
     args.broadcast_port,
-    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
     discovery_timeout=args.discovery_timeout,
     allowed_node_ids=allowed_node_ids
   )
@@ -107,7 +107,7 @@ elif args.discovery_module == "tailscale":
   discovery = TailscaleDiscovery(
     args.node_id,
     args.node_port,
-    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
     discovery_timeout=args.discovery_timeout,
     tailscale_api_key=args.tailscale_api_key,
     tailnet=args.tailnet_name,
@@ -116,7 +116,7 @@ elif args.discovery_module == "tailscale":
 elif args.discovery_module == "manual":
   if not args.discovery_config_path:
     raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
-  discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+  discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
   args.node_id,

+ 12 - 5
exo/networking/grpc/grpc_peer_handle.py

@@ -14,9 +14,10 @@ from exo.helpers import DEBUG
 
 
 class GRPCPeerHandle(PeerHandle):
-  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
+  def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
     self._id = _id
     self.address = address
+    self.desc = desc
     self._device_capabilities = device_capabilities
     self.channel = None
     self.stub = None
@@ -27,6 +28,9 @@ class GRPCPeerHandle(PeerHandle):
   def addr(self) -> str:
     return self.address
 
+  def description(self) -> str:
+    return self.desc
+
   def device_capabilities(self) -> DeviceCapabilities:
     return self._device_capabilities
 
@@ -119,12 +123,15 @@ class GRPCPeerHandle(PeerHandle):
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
       device_capabilities = DeviceCapabilities(
-        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+        model=capabilities.model,
+        chip=capabilities.chip,
+        memory=capabilities.memory,
+        flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
       )
       topology.update_node(node_id, device_capabilities)
-    for node_id, peers in response.peer_graph.items():
-      for peer_id in peers.peer_ids:
-        topology.add_edge(node_id, peer_id)
+    for node_id, peer_connections in response.peer_graph.items():
+      for conn in peer_connections.connections:
+        topology.add_edge(node_id, conn.to_id, conn.description)
     return topology
 
   async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:

+ 9 - 1
exo/networking/grpc/grpc_server.py

@@ -96,7 +96,15 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         )
       for node_id, cap in topology.nodes.items()
     }
-    peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
+    peer_graph = {
+      node_id: node_service_pb2.PeerConnections(
+        connections=[
+          node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
+          for conn in connections
+        ]
+      )
+      for node_id, connections in topology.peer_graph.items()
+    }
     if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
     return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
 

+ 8 - 3
exo/networking/grpc/node_service.proto

@@ -53,11 +53,16 @@ message CollectTopologyRequest {
 
 message Topology {
   map<string, DeviceCapabilities> nodes = 1;
-  map<string, Peers> peer_graph = 2;
+  map<string, PeerConnections> peer_graph = 2;
 }
 
-message Peers {
-    repeated string peer_ids = 1;
+message PeerConnection {
+  string to_id = 1;
+  optional string description = 2;
+}
+
+message PeerConnections {
+  repeated PeerConnection connections = 1;
 }
 
 message DeviceFlops {

Những thay đổi đã bị hủy bỏ vì nó quá lớn
+ 11 - 1
exo/networking/grpc/node_service_pb2.py


+ 2 - 7
exo/networking/grpc/node_service_pb2_grpc.py

@@ -5,10 +5,8 @@ import warnings
 
 from . import node_service_pb2 as node__service__pb2
 
-GRPC_GENERATED_VERSION = '1.64.1'
+GRPC_GENERATED_VERSION = '1.68.0'
 GRPC_VERSION = grpc.__version__
-EXPECTED_ERROR_RELEASE = '1.65.0'
-SCHEDULED_RELEASE_DATE = 'June 25, 2024'
 _version_not_supported = False
 
 try:
@@ -18,15 +16,12 @@ except ImportError:
     _version_not_supported = True
 
 if _version_not_supported:
-    warnings.warn(
+    raise RuntimeError(
         f'The grpc package installed is at version {GRPC_VERSION},'
         + f' but the generated code in node_service_pb2_grpc.py depends on'
         + f' grpcio>={GRPC_GENERATED_VERSION}.'
         + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
         + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
-        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
-        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
-        RuntimeWarning
     )
 
 

+ 1 - 1
exo/networking/manual/manual_discovery.py

@@ -53,7 +53,7 @@ class ManualDiscovery(Discovery):
           peer = self.known_peers.get(peer_id)
           if not peer:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
-            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
+            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities)
           is_healthy = await peer.health_check()
           if is_healthy:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")

+ 5 - 5
exo/networking/manual/test_manual_discovery.py

@@ -14,7 +14,7 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
     self.peer1 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
     _ = self.discovery1.start()
 
   async def asyncTearDown(self):
@@ -33,8 +33,8 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
     self.peer2 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
     self.peer2.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
     await self.discovery1.start()
     await self.discovery2.start()
 
@@ -63,8 +63,8 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
     await self.server1.start()
     await self.server2.start()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
     await self.discovery1.start()
     await self.discovery2.start()
 

+ 4 - 0
exo/networking/peer_handle.py

@@ -15,6 +15,10 @@ class PeerHandle(ABC):
   def addr(self) -> str:
     pass
 
+  @abstractmethod
+  def description(self) -> str:
+    pass
+
   @abstractmethod
   def device_capabilities(self) -> DeviceCapabilities:
     pass

+ 2 - 2
exo/networking/tailscale/tailscale_discovery.py

@@ -14,7 +14,7 @@ class TailscaleDiscovery(Discovery):
     self,
     node_id: str,
     node_port: int,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
     discovery_interval: int = 5,
     discovery_timeout: int = 30,
     update_interval: int = 15,
@@ -91,7 +91,7 @@ class TailscaleDiscovery(Discovery):
             continue
 
           if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
-            new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+            new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", "TS", device_capabilities)
             if not await new_peer_handle.health_check():
               if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
               continue

+ 1 - 1
exo/networking/tailscale/test_tailscale_discovery.py

@@ -13,7 +13,7 @@ class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
     self.discovery = TailscaleDiscovery(
       node_id="test_node",
       node_port=50051,
-      create_peer_handle=lambda peer_id, address, device_capabilities: unittest.mock.Mock(spec=PeerHandle, id=lambda: peer_id),
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: unittest.mock.Mock(spec=PeerHandle, id=lambda: peer_id),
       tailscale_api_key=self.tailscale_api_key,
       tailnet=self.tailnet
     )

+ 4 - 4
exo/networking/udp/test_udp_discovery.py

@@ -13,8 +13,8 @@ class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
     self.peer2 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
     self.peer2.connect = mock.AsyncMock()
-    self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
-    self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
+    self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
     await self.discovery1.start()
     await self.discovery2.start()
 
@@ -41,8 +41,8 @@ class TestUDPDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.server2 = GRPCServer(self.node2, "localhost", 50054)
     await self.server1.start()
     await self.server2.start()
-    self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
-    self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
+    self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
     await self.discovery1.start()
     await self.discovery2.start()
 

+ 12 - 11
exo/networking/udp/udp_discovery.py

@@ -7,7 +7,7 @@ from typing import List, Dict, Callable, Tuple, Coroutine
 from exo.networking.discovery import Discovery
 from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
-from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
+from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses_and_interfaces, get_interface_priority_and_type
 
 
 class ListenProtocol(asyncio.DatagramProtocol):
@@ -41,7 +41,7 @@ class UDPDiscovery(Discovery):
     node_port: int,
     listen_port: int,
     broadcast_port: int,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
     broadcast_interval: int = 1,
     discovery_timeout: int = 30,
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
@@ -87,27 +87,27 @@ class UDPDiscovery(Discovery):
     while True:
       # Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
       # the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
-      for addr in get_all_ip_addresses():
+      for addr, interface_name in get_all_ip_addresses_and_interfaces():
+        interface_priority, _ = get_interface_priority_and_type(interface_name)
         message = json.dumps({
           "type": "discovery",
           "node_id": self.node_id,
           "grpc_port": self.node_port,
           "device_capabilities": self.device_capabilities.to_dict(),
-          "priority": 1,  # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "priority": interface_priority, # TODO: Prioritise interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "interface_name": interface_name,
         })
-        if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")
+        if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority}): {message}")
 
         transport = None
         try:
           transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
-          if DEBUG_DISCOVERY >= 3:
-            print(f"Broadcasting presence at ({addr})")
+          if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority})")
         except Exception as e:
-          print(f"Error in broadcast presence ({addr}): {e}")
+          print(f"Error in broadcast presence ({addr} - {interface_name} - {interface_priority}): {e}")
         finally:
           if transport:
-            try:
-              transport.close()
+            try: transport.close()
             except Exception as e:
               if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
               if DEBUG_DISCOVERY >= 2: traceback.print_exc()
@@ -144,6 +144,7 @@ class UDPDiscovery(Discovery):
       peer_host = addr[0]
       peer_port = message["grpc_port"]
       peer_prio = message["priority"]
+      peer_interface_name = message["interface_name"]
       device_capabilities = DeviceCapabilities(**message["device_capabilities"])
 
       if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
@@ -153,7 +154,7 @@ class UDPDiscovery(Discovery):
             if DEBUG >= 1:
               print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
             return
-        new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+        new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", peer_interface_name, device_capabilities)
         if not await new_peer_handle.health_check():
           if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
           return

+ 1 - 1
exo/orchestration/standard_node.py

@@ -398,7 +398,7 @@ class StandardNode(Node):
 
     for peer in self.peers:
       next_topology.update_node(peer.id(), peer.device_capabilities())
-      next_topology.add_edge(self.id, peer.id())
+      next_topology.add_edge(self.id, peer.id(), peer.description())
 
       if peer.id() in prev_visited:
         continue

+ 40 - 15
exo/topology/topology.py

@@ -1,11 +1,28 @@
 from .device_capabilities import DeviceCapabilities
-from typing import Dict, Set, Optional
+from typing import Dict, Set, Optional, NamedTuple
+from dataclasses import dataclass
 
+@dataclass
+class PeerConnection:
+  from_id: str
+  to_id: str
+  description: Optional[str] = None
+
+  def __hash__(self):
+    # Use both from_id and to_id for uniqueness in sets
+    return hash((self.from_id, self.to_id))
+
+  def __eq__(self, other):
+    if not isinstance(other, PeerConnection):
+      return False
+    # Compare both from_id and to_id for equality
+    return self.from_id == other.from_id and self.to_id == other.to_id
 
 class Topology:
   def __init__(self):
-    self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
-    self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph
+    self.nodes: Dict[str, DeviceCapabilities] = {}
+    # Store PeerConnection objects in the adjacency lists
+    self.peer_graph: Dict[str, Set[PeerConnection]] = {}
     self.active_node_id: Optional[str] = None
 
   def update_node(self, node_id: str, device_capabilities: DeviceCapabilities):
@@ -17,33 +34,41 @@ class Topology:
   def all_nodes(self):
     return self.nodes.items()
 
-  def add_edge(self, node1_id: str, node2_id: str):
+  def add_edge(self, node1_id: str, node2_id: str, description: Optional[str] = None):
     if node1_id not in self.peer_graph:
       self.peer_graph[node1_id] = set()
     if node2_id not in self.peer_graph:
       self.peer_graph[node2_id] = set()
-    self.peer_graph[node1_id].add(node2_id)
-    self.peer_graph[node2_id].add(node1_id)
+
+    # Create bidirectional connections with the same description
+    conn1 = PeerConnection(node1_id, node2_id, description)
+    conn2 = PeerConnection(node2_id, node1_id, description)
+
+    self.peer_graph[node1_id].add(conn1)
+    self.peer_graph[node2_id].add(conn2)
 
   def get_neighbors(self, node_id: str) -> Set[str]:
-    return self.peer_graph.get(node_id, set())
+    # Convert PeerConnection objects back to just destination IDs
+    return {conn.to_id for conn in self.peer_graph.get(node_id, set())}
 
   def all_edges(self):
     edges = []
-    for node, neighbors in self.peer_graph.items():
-      for neighbor in neighbors:
-        if (neighbor, node) not in edges:  # Avoid duplicate edges
-          edges.append((node, neighbor))
+    for node_id, connections in self.peer_graph.items():
+      for conn in connections:
+        # Only include each edge once by checking if reverse already exists
+        if not any(e[0] == conn.to_id and e[1] == conn.from_id for e in edges):
+          edges.append((conn.from_id, conn.to_id, conn.description))
     return edges
 
   def merge(self, other: "Topology"):
     for node_id, capabilities in other.nodes.items():
       self.update_node(node_id, capabilities)
-    for node_id, neighbors in other.peer_graph.items():
-      for neighbor in neighbors:
-        self.add_edge(node_id, neighbor)
+    for node_id, connections in other.peer_graph.items():
+      for conn in connections:
+        self.add_edge(conn.from_id, conn.to_id, conn.description)
 
   def __str__(self):
     nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items())
-    edges_str = ", ".join(f"{node}: {neighbors}" for node, neighbors in self.peer_graph.items())
+    edges_str = ", ".join(f"{node}: {[f'{c.to_id}({c.description})' for c in conns]}"
+                         for node, conns in self.peer_graph.items())
     return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})"

+ 15 - 1
exo/viz/topology_viz.py

@@ -242,12 +242,19 @@ class TopologyViz:
             if info_y + j != y or info_x + k != x:
               visualization[info_y + j][info_x + k] = char
 
-      # Draw line to next node
+      # Draw line to next node and add connection description
       next_i = (i+1) % num_partitions
       next_angle = 2*math.pi*next_i/num_partitions
       next_x = int(center_x + radius_x*math.cos(next_angle))
       next_y = int(center_y + radius_y*math.sin(next_angle))
 
+      # Get connection descriptions
+      conn1 = self.topology.peer_graph.get(partition.node_id, set())
+      conn2 = self.topology.peer_graph.get(self.partitions[next_i].node_id, set())
+      description1 = next((c.description for c in conn1 if c.to_id == self.partitions[next_i].node_id), "")
+      description2 = next((c.description for c in conn2 if c.to_id == partition.node_id), "")
+      connection_description = f"{description1}/{description2}" if description1 != description2 else description1
+
       # Simple line drawing
       steps = max(abs(next_x - x), abs(next_y - y))
       for step in range(1, steps):
@@ -256,6 +263,13 @@ class TopologyViz:
         if 0 <= line_y < 48 and 0 <= line_x < 100:
           visualization[line_y][line_x] = "-"
 
+      # Add connection description near the midpoint of the line
+      mid_x = (x + next_x) // 2
+      mid_y = (y + next_y) // 2
+      for j, char in enumerate(connection_description):
+        if 0 <= mid_y < 48 and 0 <= mid_x + j < 100:
+          visualization[mid_y][mid_x + j] = char
+
     # Convert to string
     return "\n".join("".join(str(char) for char in row) for row in visualization)
 

Một số tệp đã không được hiển thị bởi vì có quá nhiều tệp thay đổi trong diff này