Browse Source

Handle the case where a peer is removed from the config, so that the known_peers dict is updated accordingly

Ian Paul 7 tháng trước cách đây
mục cha
commit
e5eb3259a5

+ 67 - 14
exo/networking/manual/manual_discovery.py

@@ -1,6 +1,7 @@
+import os
 import asyncio
 import asyncio
 from exo.networking.discovery import Discovery
 from exo.networking.discovery import Discovery
-from typing import Dict, List, Callable
+from typing import Dict, List, Callable, Optional
 
 
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
 from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
@@ -27,12 +28,16 @@ class ManualDiscovery(Discovery):
         self.listen_task = None
         self.listen_task = None
         self.known_peers: Dict[str, PeerHandle] = {}
         self.known_peers: Dict[str, PeerHandle] = {}
 
 
-    async def start(self) -> None:
+      self._cached_peers: Dict[str, PeerConfig] = {}
+    self._last_modified_time: Optional[float] = None
+
+  async def start(self) -> None:
+    """Launch the two background tasks that poll the network config.
+
+    `listen_task` runs `task_find_peers_from_config` and `cleanup_task` runs
+    `task_clean_up_peers_from_config`; both handles are kept on the instance
+    so that `stop()` can cancel them.
+    """
         self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
         self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
+    self.cleanup_task = asyncio.create_task(self.task_clean_up_peers_from_config())
 
 
-    async def stop(self) -> None:
-        if self.listen_task:
-            self.listen_task.cancel()
+  async def stop(self) -> None:
+    """Cancel the background config-polling tasks, if they were started.
+
+    Safe to call before `start()`: in the visible changes `cleanup_task` is
+    first assigned in `start()`, so each task handle is looked up with a
+    default instead of assuming the attribute already exists.
+    """
+    for task_name in ("listen_task", "cleanup_task"):
+      task = getattr(self, task_name, None)
+      if task: task.cancel()
 
 
     async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
     async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
         if wait_for_peers > 0:
         if wait_for_peers > 0:
@@ -48,6 +53,41 @@ class ManualDiscovery(Discovery):
             )
             )
         return list(self.known_peers.values())
         return list(self.known_peers.values())
 
 
+  async def task_clean_up_peers_from_config(self):
+    """Periodically drop known peers that are no longer listed in the config.
+
+    Loops until the task is cancelled. Once per second the result of
+    `_get_peers()` is compared against `known_peers`, and every known peer
+    whose id is absent from the config is removed.
+    """
+    if DEBUG_DISCOVERY >= 2: print("Starting task to clean up peers from config...")
+    while True:
+      peers_from_config = self._get_peers()
+      # Skip only when nothing was returned at all. An empty dict is a valid
+      # result (every other peer was removed from the config) and must still
+      # evict the stale entries; a plain truthiness test would ignore it.
+      if peers_from_config is not None:
+        # Collect the ids first so the dict is not mutated while iterating it.
+        peers_to_remove = [peer for peer in self.known_peers if peer not in peers_from_config]
+
+        for peer in peers_to_remove:
+          if DEBUG_DISCOVERY >= 2: print(f"{peer} is no longer found in the config but is currently a known peer. Removing from known peers...")
+          # Tolerate an already-missing key instead of raising KeyError.
+          self.known_peers.pop(peer, None)
+
+      await asyncio.sleep(1.0)
+
+  async def task_find_peers_from_config(self):
+    """Periodically health-check every configured peer and track the healthy ones.
+
+    Loops until the task is cancelled. Once per second each peer returned by
+    `_get_peers()` is looked up in `known_peers` (a handle is created when it
+    is missing) and health-checked: healthy peers are stored in
+    `known_peers`, unhealthy ones are removed from it. Any exception raised
+    while processing a peer is reported at debug level and the loop moves on
+    to the next peer.
+    """
+    if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
+    while True:
+      for peer_id, peer_config in self._get_peers().items():
+        try:
+          if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
+          handle = self.known_peers.get(peer_id)
+          if not handle:
+            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
+            handle = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
+          if await handle.health_check():
+            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
+            self.known_peers[peer_id] = handle
+            continue
+          # Unhealthy: forget the peer; a missing key is not an error.
+          if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
+          self.known_peers.pop(peer_id, None)
+        except Exception as e:
+          if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
+      await asyncio.sleep(1.0)
     async def task_find_peers_from_config(self):
     async def task_find_peers_from_config(self):
         if DEBUG_DISCOVERY >= 2:
         if DEBUG_DISCOVERY >= 2:
             print("Starting task to find peers from config...")
             print("Starting task to find peers from config...")
@@ -103,15 +143,28 @@ class ManualDiscovery(Discovery):
 
 
   def _get_peers(self):
   def _get_peers(self):
+    """Return the peers listed in the network config, excluding this node.
+
+    The parsed result is cached together with the config file's modification
+    time and reused for as long as that time is unchanged.
+
+    Returns:
+      A dict mapping peer id to `PeerConfig`. If the file cannot be read or
+      parsed, or does not list this node's id, the error is reported at
+      debug level and the last successfully loaded peers are returned
+      (`_cached_peers`, which starts out as an empty dict).
+    """
     try:
     try:
-        topology = NetworkTopology.from_path(self.network_config_path)
+      current_mtime = os.path.getmtime(self.network_config_path)
 
 
-        if self.node_id not in topology.peers:
-          raise ValueError(f"Node ID {self.node_id} not found in network config file {self.network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in topology.peers]}")
+      # Reuse the cached peers while the file is untouched. Compare for equality
+      # rather than "not newer": a config restored from a backup can carry an
+      # older mtime and must still be reloaded.
+      if self._cached_peers is not None and self._last_modified_time is not None and current_mtime == self._last_modified_time:
+        return self._cached_peers
 
 
-        peers_in_network: Dict[str, PeerConfig] = topology.peers
-        peers_in_network.pop(self.node_id)
-    except Exception as e:
-        if DEBUG_DISCOVERY >= 2: print(f"Error when loading network config file from {self.network_config_path}. Please update the config file in order to successfully discover peers. Exception: {e}")
-        peers_in_network = {}
+      topology = NetworkTopology.from_path(self.network_config_path)
 
 
-        return peers_in_network
+      if self.node_id not in topology.peers:
+        # Iterating the dict yields its keys directly; unpacking each key into
+        # two names would itself raise and hide this message.
+        raise ValueError(
+          f"Node ID {self.node_id} not found in network config file "
+          f"{self.network_config_path}. Please run with `node_id` set to "
+          f"one of the keys in the config file: {list(topology.peers)}"
+        )
+
+      # This node is not its own peer.
+      peers_in_network: Dict[str, PeerConfig] = topology.peers
+      peers_in_network.pop(self.node_id)
+
+      self._cached_peers = peers_in_network
+      self._last_modified_time = current_mtime
+
+      return peers_in_network
+
+    except Exception as e:
+      if DEBUG_DISCOVERY >= 2: print(f"Error when loading network config file from {self.network_config_path}. Please update the config file in order to successfully discover peers. Exception: {e}")
+      # Fall back to the last good config instead of forgetting every peer.
+      return self._cached_peers

+ 7 - 0
exo/networking/manual/test_manual_discovery.py

@@ -145,6 +145,13 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
       with open(root_path, "w") as f:
       with open(root_path, "w") as f:
         json.dump(original_config, f, indent=2)
         json.dump(original_config, f, indent=2)
 
 
+    # Wait for the config to be reloaded again
+    await asyncio.sleep(1.5)
+
+    updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
+    self.assertEqual(len(updated_peers), 1)
+
+
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
   asyncio.run(unittest.main())
   asyncio.run(unittest.main())