Przeglądaj źródła

ignore topology merges from the non-owner

Alex Cheema 5 miesięcy temu
rodzic
commit
69c18d9a93

+ 1 - 3
exo/networking/grpc/grpc_peer_handle.py

@@ -117,12 +117,11 @@ class GRPCPeerHandle(PeerHandle):
       response.is_finished,
     )
 
-  async def collect_topology(self, my_node_id: str, visited: set[str], max_depth: int) -> Topology:
+  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
     request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
     response = await self.stub.CollectTopology(request)
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
-      if node_id == my_node_id: continue
       device_capabilities = DeviceCapabilities(
         model=capabilities.model,
         chip=capabilities.chip,
@@ -131,7 +130,6 @@ class GRPCPeerHandle(PeerHandle):
       )
       topology.update_node(node_id, device_capabilities)
     for node_id, peer_connections in response.peer_graph.items():
-      if node_id == my_node_id: continue
       for conn in peer_connections.connections:
         topology.add_edge(node_id, conn.to_id, conn.description)
     return topology

+ 1 - 1
exo/networking/peer_handle.py

@@ -56,5 +56,5 @@ class PeerHandle(ABC):
     pass
 
   @abstractmethod
-  async def collect_topology(self, my_node_id: str, visited: set[str], max_depth: int) -> Topology:
+  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
     pass

+ 1 - 1
exo/orchestration/node.py

@@ -28,7 +28,7 @@ class Node(ABC):
     pass
 
   @abstractmethod
-  async def collect_topology(self, my_node_id: str, visited: set[str] = set(), max_depth: int = 2) -> Topology:
+  async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
     pass
 
   @property

+ 2 - 2
exo/orchestration/standard_node.py

@@ -408,9 +408,9 @@ class StandardNode(Node):
         continue
 
       try:
-        other_topology = await asyncio.wait_for(peer.collect_topology(self.id, visited, max_depth=max_depth - 1), timeout=5.0)
+        other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0)
         if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
-        next_topology.merge(other_topology)
+        next_topology.merge(peer.id(), other_topology)
       except Exception as e:
         print(f"Error collecting topology from {peer.id()}: {e}")
         traceback.print_exc()

+ 3 - 1
exo/topology/topology.py

@@ -39,11 +39,13 @@ class Topology:
     conn = PeerConnection(from_id, to_id, description)
     self.peer_graph[from_id].add(conn)
 
-  def merge(self, other: "Topology"):
+  def merge(self, peer_node_id: str, other: "Topology"):
     for node_id, capabilities in other.nodes.items():
+      if node_id != peer_node_id: continue
       self.update_node(node_id, capabilities)
     for node_id, connections in other.peer_graph.items():
       for conn in connections:
+        if conn.from_id != peer_node_id: continue
         self.add_edge(conn.from_id, conn.to_id, conn.description)
 
   def __str__(self):