瀏覽代碼

pass origin_node_id to merge

Alex Cheema 8 月之前
父節點
當前提交
657520ed4c
共有 2 個文件被更改,包括 6 次插入5 次删除
  1. 1 1
      exo/orchestration/standard_node.py
  2. 5 4
      exo/topology/topology.py

+ 1 - 1
exo/orchestration/standard_node.py

@@ -410,7 +410,7 @@ class StandardNode(Node):
       try:
         other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0)
         if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
-        next_topology.merge(other_topology)
+        next_topology.merge(self.id, other_topology)
       except Exception as e:
         print(f"Error collecting topology from {peer.id()}: {e}")
         traceback.print_exc()

+ 5 - 4
exo/topology/topology.py

@@ -21,7 +21,6 @@ class PeerConnection:
 class Topology:
   def __init__(self):
     self.nodes: Dict[str, DeviceCapabilities] = {}
-    # Store PeerConnection objects in the adjacency lists
     self.peer_graph: Dict[str, Set[PeerConnection]] = {}
     self.active_node_id: Optional[str] = None
 
@@ -40,12 +39,14 @@ class Topology:
     conn = PeerConnection(from_id, to_id, description)
     self.peer_graph[from_id].add(conn)
 
-  def merge(self, other: "Topology"):
+  def merge(self, origin_node_id: str, other: "Topology"):
     for node_id, capabilities in other.nodes.items():
-      self.update_node(node_id, capabilities)
+      if node_id != origin_node_id:
+        self.update_node(node_id, capabilities)
     for node_id, connections in other.peer_graph.items():
       for conn in connections:
-        self.add_edge(conn.from_id, conn.to_id, conn.description)
+        if conn.from_id != origin_node_id:
+          self.add_edge(conn.from_id, conn.to_id, conn.description)
 
   def __str__(self):
     nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items())