|
@@ -147,20 +147,29 @@ class StandardNode(Node):
|
|
await peer.connect()
|
|
await peer.connect()
|
|
if DEBUG >= 2: print(f"Connected to peer {peer.id()}")
|
|
if DEBUG >= 2: print(f"Connected to peer {peer.id()}")
|
|
|
|
|
|
- async def collect_topology(self, max_depth: int = 4) -> Topology:
|
|
|
|
|
|
+ async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
|
|
self.topology.update_node(self.id, self.device_capabilities)
|
|
self.topology.update_node(self.id, self.device_capabilities)
|
|
|
|
|
|
|
|
+ if DEBUG >= 2: print(f"Collecting topoloy {max_depth=} {visited=}")
|
|
for peer in self.peers:
|
|
for peer in self.peers:
|
|
self.topology.update_node(peer.id(), peer.device_capabilities())
|
|
self.topology.update_node(peer.id(), peer.device_capabilities())
|
|
self.topology.add_edge(self.id, peer.id())
|
|
self.topology.add_edge(self.id, peer.id())
|
|
|
|
|
|
- if max_depth > 0:
|
|
|
|
- try:
|
|
|
|
- other_topology = await peer.collect_topology(max_depth = max_depth - 1)
|
|
|
|
- if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
|
|
|
|
- self.topology.merge(other_topology)
|
|
|
|
- except Exception as e:
|
|
|
|
- print(f"Error collecting topology from {peer.id()}: {e}")
|
|
|
|
|
|
+ if peer.id() in visited:
|
|
|
|
+ if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
|
|
|
|
+ continue
|
|
|
|
+ visited.add(peer.id())
|
|
|
|
+
|
|
|
|
+ if max_depth <= 0:
|
|
|
|
+ if DEBUG >= 2: print(f"Max depth reached. Skipping...")
|
|
|
|
+ continue
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ other_topology = await peer.collect_topology(visited, max_depth = max_depth - 1)
|
|
|
|
+ if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
|
|
|
|
+ self.topology.merge(other_topology)
|
|
|
|
+ except Exception as e:
|
|
|
|
+ print(f"Error collecting topology from {peer.id()}: {e}")
|
|
|
|
|
|
return self.topology
|
|
return self.topology
|
|
|
|
|
|
@@ -180,3 +189,6 @@ class StandardNode(Node):
|
|
if request_id not in self.buffered_token_output:
|
|
if request_id not in self.buffered_token_output:
|
|
return None, False
|
|
return None, False
|
|
return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
|
|
return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
|
|
|
|
+
|
|
|
|
+ async def global_reset(self, max_depth: int = 2) -> None:
|
|
|
|
+ pass
|