浏览代码

topo fix: each node only takes its own entry as the source of truth

Alex Cheema 5 月之前
父节点
当前提交
50e4a966e0

+ 3 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -117,11 +117,12 @@ class GRPCPeerHandle(PeerHandle):
       response.is_finished,
     )
 
-  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+  async def collect_topology(self, my_node_id: str, visited: set[str], max_depth: int) -> Topology:
     request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
     response = await self.stub.CollectTopology(request)
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
+      if node_id == my_node_id: continue
       device_capabilities = DeviceCapabilities(
         model=capabilities.model,
         chip=capabilities.chip,
@@ -130,6 +131,7 @@ class GRPCPeerHandle(PeerHandle):
       )
       topology.update_node(node_id, device_capabilities)
     for node_id, peer_connections in response.peer_graph.items():
+      if node_id == my_node_id: continue
       for conn in peer_connections.connections:
         topology.add_edge(node_id, conn.to_id, conn.description)
     return topology

+ 1 - 1
exo/networking/peer_handle.py

@@ -56,5 +56,5 @@ class PeerHandle(ABC):
     pass
 
   @abstractmethod
-  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+  async def collect_topology(self, my_node_id: str, visited: set[str], max_depth: int) -> Topology:
     pass

+ 1 - 1
exo/orchestration/node.py

@@ -28,7 +28,7 @@ class Node(ABC):
     pass
 
   @abstractmethod
-  async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
+  async def collect_topology(self, my_node_id: str, visited: set[str] = set(), max_depth: int = 2) -> Topology:
     pass
 
   @property

+ 1 - 1
exo/orchestration/standard_node.py

@@ -408,7 +408,7 @@ class StandardNode(Node):
         continue
 
       try:
-        other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0)
+        other_topology = await asyncio.wait_for(peer.collect_topology(self.id, visited, max_depth=max_depth - 1), timeout=5.0)
         if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
         next_topology.merge(other_topology)
       except Exception as e: