|
@@ -147,18 +147,38 @@ class StandardNode(Node):
|
|
await peer.connect()
|
|
await peer.connect()
|
|
if DEBUG >= 2: print(f"Connected to peer {peer.id()}")
|
|
if DEBUG >= 2: print(f"Connected to peer {peer.id()}")
|
|
|
|
|
|
|
|
+ async def periodic_topology_collection(self, interval: int):
|
|
|
|
+ while True:
|
|
|
|
+ await asyncio.sleep(interval)
|
|
|
|
+ try:
|
|
|
|
+ await self.update_peers()
|
|
|
|
+ await self.collect_topology()
|
|
|
|
+ except Exception as e:
|
|
|
|
+ print(f"Error collecting topology: {e}")
|
|
|
|
+
|
|
|
|
+ if DEBUG >= 2: print("Topology collection task executed.")
|
|
|
|
+ if DEBUG >= 2: print(f"Current topology: {self.topology}")
|
|
|
|
+
|
|
|
|
+ async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
|
|
|
|
+ if request_id not in self.buffered_token_output:
|
|
|
|
+ return None, False
|
|
|
|
+ return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
|
|
|
|
+
|
|
async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
|
|
async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
|
|
self.topology.update_node(self.id, self.device_capabilities)
|
|
self.topology.update_node(self.id, self.device_capabilities)
|
|
|
|
|
|
if DEBUG >= 2: print(f"Collecting topoloy {max_depth=} {visited=}")
|
|
if DEBUG >= 2: print(f"Collecting topoloy {max_depth=} {visited=}")
|
|
|
|
+
|
|
|
|
+ prev_visited = visited.copy()
|
|
|
|
+ visited.update(p.id() for p in self.peers)
|
|
|
|
+
|
|
for peer in self.peers:
|
|
for peer in self.peers:
|
|
self.topology.update_node(peer.id(), peer.device_capabilities())
|
|
self.topology.update_node(peer.id(), peer.device_capabilities())
|
|
self.topology.add_edge(self.id, peer.id())
|
|
self.topology.add_edge(self.id, peer.id())
|
|
|
|
|
|
- if peer.id() in visited:
|
|
|
|
|
|
+ if peer.id() in prev_visited:
|
|
if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
|
|
if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
|
|
continue
|
|
continue
|
|
- visited.add(peer.id())
|
|
|
|
|
|
|
|
if max_depth <= 0:
|
|
if max_depth <= 0:
|
|
if DEBUG >= 2: print(f"Max depth reached. Skipping...")
|
|
if DEBUG >= 2: print(f"Max depth reached. Skipping...")
|
|
@@ -173,22 +193,26 @@ class StandardNode(Node):
|
|
|
|
|
|
return self.topology
|
|
return self.topology
|
|
|
|
|
|
- async def periodic_topology_collection(self, interval: int):
|
|
|
|
- while True:
|
|
|
|
- await asyncio.sleep(interval)
|
|
|
|
- try:
|
|
|
|
- await self.update_peers()
|
|
|
|
- await self.collect_topology()
|
|
|
|
- except Exception as e:
|
|
|
|
- print(f"Error collecting topology: {e}")
|
|
|
|
|
|
+ # TODO: unify this and collect_topology as global actions
|
|
|
|
+ async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:
|
|
|
|
+ await self.reset_shard(self.get_current_shard(base_shard))
|
|
|
|
|
|
- if DEBUG >= 2: print("Topology collection task executed.")
|
|
|
|
- if DEBUG >= 2: print(f"Current topology: {self.topology}")
|
|
|
|
|
|
+ if DEBUG >= 2: print(f"Global reset {base_shard=} {max_depth=} {visited=}")
|
|
|
|
|
|
- async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
|
|
|
|
- if request_id not in self.buffered_token_output:
|
|
|
|
- return None, False
|
|
|
|
- return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
|
|
|
|
|
|
+ prev_visited = visited.copy()
|
|
|
|
+ visited.update(p.id() for p in self.peers)
|
|
|
|
|
|
- async def global_reset(self, max_depth: int = 2) -> None:
|
|
|
|
- pass
|
|
|
|
|
|
+ for peer in self.peers:
|
|
|
|
+ if peer.id() in prev_visited:
|
|
|
|
+ if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
|
|
|
|
+ continue
|
|
|
|
+
|
|
|
|
+ if max_depth <= 0:
|
|
|
|
+ if DEBUG >= 2: print(f"Max depth reached. Skipping...")
|
|
|
|
+ continue
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ print(f"Forwarding global reset to peer {peer.id()}")
|
|
|
|
+ await peer.global_reset(base_shard, visited, max_depth = max_depth - 1)
|
|
|
|
+ except Exception as e:
|
|
|
|
+ print(f"Error collecting topology from {peer.id()}: {e}")
|