瀏覽代碼

keep track of already visited peers in global operations: collect_topology

Alex Cheema 10 月之前
父節點
當前提交
d2184f583a

+ 2 - 2
exo/networking/grpc/grpc_peer_handle.py

@@ -77,8 +77,8 @@ class GRPCPeerHandle(PeerHandle):
         request = node_service_pb2.ResetShardRequest(shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
         request = node_service_pb2.ResetShardRequest(shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
         await self.stub.ResetShard(request)
         await self.stub.ResetShard(request)
 
 
-    async def collect_topology(self, max_depth: int) -> Topology:
-        request = node_service_pb2.CollectTopologyRequest(max_depth=max_depth)
+    async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+        request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
         response = await self.stub.CollectTopology(request)
         response = await self.stub.CollectTopology(request)
         topology = Topology()
         topology = Topology()
         for node_id, capabilities in response.nodes.items():
         for node_id, capabilities in response.nodes.items():

+ 2 - 1
exo/networking/grpc/grpc_server.py

@@ -65,7 +65,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
 
 
     async def CollectTopology(self, request, context):
     async def CollectTopology(self, request, context):
         max_depth = request.max_depth
         max_depth = request.max_depth
-        topology = await self.node.collect_topology(max_depth)
+        visited = set(request.visited)
+        topology = await self.node.collect_topology(visited, max_depth)
         nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory) for node_id, cap in topology.nodes.items()}
         nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory) for node_id, cap in topology.nodes.items()}
         peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
         peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
         return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
         return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)

+ 2 - 1
exo/networking/grpc/node_service.proto

@@ -49,7 +49,8 @@ message ResetShardRequest {
 }
 }
 
 
 message CollectTopologyRequest {
 message CollectTopologyRequest {
-  int32 max_depth = 1;
+  repeated string visited = 1;
+  int32 max_depth = 2;
 }
 }
 
 
 message Topology {
 message Topology {

文件差異過大導致無法顯示
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 1 - 1
exo/networking/peer_handle.py

@@ -42,5 +42,5 @@ class PeerHandle(ABC):
     async def reset_shard(self, shard: Shard) -> None:
     async def reset_shard(self, shard: Shard) -> None:
         pass
         pass
 
 
-    async def collect_topology(self, max_depth: int) -> Topology:
+    async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
         pass
         pass

+ 5 - 1
exo/orchestration/node.py

@@ -26,9 +26,13 @@ class Node(ABC):
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
-    async def collect_topology(self, max_depth: int = 2) -> Topology:
+    async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
     async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
     async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
         pass
         pass
+
+    @abstractmethod
+    async def global_reset(self, visited: set[str], max_depth: int = 2) -> None:
+        pass

+ 20 - 8
exo/orchestration/standard_node.py

@@ -147,20 +147,29 @@ class StandardNode(Node):
                 await peer.connect()
                 await peer.connect()
                 if DEBUG >= 2: print(f"Connected to peer {peer.id()}")
                 if DEBUG >= 2: print(f"Connected to peer {peer.id()}")
 
 
-    async def collect_topology(self, max_depth: int = 4) -> Topology:
+    async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
         self.topology.update_node(self.id, self.device_capabilities)
         self.topology.update_node(self.id, self.device_capabilities)
 
 
+        if DEBUG >= 2: print(f"Collecting topoloy {max_depth=} {visited=}")
         for peer in self.peers:
         for peer in self.peers:
             self.topology.update_node(peer.id(), peer.device_capabilities())
             self.topology.update_node(peer.id(), peer.device_capabilities())
             self.topology.add_edge(self.id, peer.id())
             self.topology.add_edge(self.id, peer.id())
 
 
-            if max_depth > 0:
-                try:
-                    other_topology = await peer.collect_topology(max_depth = max_depth - 1)
-                    if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
-                    self.topology.merge(other_topology)
-                except Exception as e:
-                    print(f"Error collecting topology from {peer.id()}: {e}")
+            if peer.id() in visited:
+                if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
+                continue
+            visited.add(peer.id())
+
+            if max_depth <= 0:
+                if DEBUG >= 2: print(f"Max depth reached. Skipping...")
+                continue
+
+            try:
+                other_topology = await peer.collect_topology(visited, max_depth = max_depth - 1)
+                if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
+                self.topology.merge(other_topology)
+            except Exception as e:
+                print(f"Error collecting topology from {peer.id()}: {e}")
 
 
         return self.topology
         return self.topology
 
 
@@ -180,3 +189,6 @@ class StandardNode(Node):
         if request_id not in self.buffered_token_output:
         if request_id not in self.buffered_token_output:
             return None, False
             return None, False
         return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
         return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
+
+    async def global_reset(self, max_depth: int = 2) -> None:
+        pass

部分文件因文件數量過多而無法顯示