|
@@ -56,7 +56,7 @@ class StandardNode(Node):
|
|
|
await self.server.start()
|
|
|
await self.discovery.start()
|
|
|
await self.update_peers(wait_for_peers)
|
|
|
- await self.collect_topology()
|
|
|
+ await self.collect_topology(set())
|
|
|
if DEBUG >= 2: print(f"Collected topology: {self.topology}")
|
|
|
asyncio.create_task(self.periodic_topology_collection(1.0))
|
|
|
|
|
@@ -83,7 +83,7 @@ class StandardNode(Node):
|
|
|
download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
|
|
|
self.node_download_progress[status_data.get('node_id')] = download_progress
|
|
|
if self.topology_viz:
|
|
|
- self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id, self.node_download_progress)
|
|
|
+ self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id, self.node_download_progress)
|
|
|
except Exception as e:
|
|
|
if DEBUG >= 1: print(f"Error updating visualization: {e}")
|
|
|
if DEBUG >= 1: traceback.print_exc()
|
|
@@ -374,8 +374,8 @@ class StandardNode(Node):
|
|
|
try:
|
|
|
did_peers_change = await self.update_peers()
|
|
|
if DEBUG >= 2: print(f"{did_peers_change=}")
|
|
|
+ await self.collect_topology(set())
|
|
|
if did_peers_change:
|
|
|
- await self.collect_topology()
|
|
|
await self.select_best_inference_engine()
|
|
|
except Exception as e:
|
|
|
print(f"Error collecting topology: {e}")
|
|
@@ -386,7 +386,7 @@ class StandardNode(Node):
|
|
|
return None, False
|
|
|
return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
|
|
|
|
|
|
- async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
|
|
|
+ async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topology:
|
|
|
next_topology = Topology()
|
|
|
next_topology.update_node(self.id, self.device_capabilities)
|
|
|
|
|
@@ -410,16 +410,16 @@ class StandardNode(Node):
|
|
|
try:
|
|
|
other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0)
|
|
|
if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
|
|
|
- self.topology.merge(other_topology)
|
|
|
+ next_topology.merge(other_topology)
|
|
|
except Exception as e:
|
|
|
print(f"Error collecting topology from {peer.id()}: {e}")
|
|
|
traceback.print_exc()
|
|
|
|
|
|
- next_topology.active_node_id = self.topology.active_node_id # this is not so clean.
|
|
|
+ next_topology.active_node_id = self.topology.active_node_id
|
|
|
self.topology = next_topology
|
|
|
if self.topology_viz:
|
|
|
- self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id)
|
|
|
- return next_topology
|
|
|
+ self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id)
|
|
|
+ return self.topology
|
|
|
|
|
|
@property
|
|
|
def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
|