Procházet zdrojové kódy

patch after rebasing to main

Ian Paul před 6 měsíci
rodič
revize
0e34ce2169
1 změnil soubory, kde provedl 51 přidání a 13 odebrání
  1. 51 13
      exo/networking/manual/manual_discovery.py

+ 51 - 13
exo/networking/manual/manual_discovery.py

@@ -1,6 +1,7 @@
+import os
 import asyncio
 import asyncio
 from exo.networking.discovery import Discovery
 from exo.networking.discovery import Discovery
-from typing import Dict, List, Callable
+from typing import Dict, List, Callable, Optional
 
 
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
 from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
@@ -15,28 +16,24 @@ class ManualDiscovery(Discovery):
     node_id: str,
     node_id: str,
     create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
     create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
   ):
   ):
-    self.topology = NetworkTopology.from_path(network_config_path)
     self.network_config_path = network_config_path
     self.network_config_path = network_config_path
     self.node_id = node_id
     self.node_id = node_id
     self.create_peer_handle = create_peer_handle
     self.create_peer_handle = create_peer_handle
 
 
-    if node_id not in self.topology.peers:
-      raise ValueError(
-        f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
-      )
-
     self.listen_task = None
     self.listen_task = None
-
+    self.cleanup_task = None
     self.known_peers: Dict[str, PeerHandle] = {}
     self.known_peers: Dict[str, PeerHandle] = {}
-    self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
-    self.peers_in_network.pop(node_id)
+
+    self._cached_peers: Dict[str, PeerConfig] = {}
+    self._last_modified_time: Optional[float] = None
 
 
   async def start(self) -> None:
   async def start(self) -> None:
     self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
     self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
+    self.cleanup_task = asyncio.create_task(self.task_clean_up_peers_from_config())
 
 
   async def stop(self) -> None:
   async def stop(self) -> None:
-    if self.listen_task:
-      self.listen_task.cancel()
+    if self.listen_task: self.listen_task.cancel()
+    if self.cleanup_task: self.cleanup_task.cancel()
 
 
   async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
   async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
     if wait_for_peers > 0:
     if wait_for_peers > 0:
@@ -49,7 +46,7 @@ class ManualDiscovery(Discovery):
   async def task_find_peers_from_config(self):
   async def task_find_peers_from_config(self):
     if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
     if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
     while True:
     while True:
-      for peer_id, peer_config in self.peers_in_network.items():
+      for peer_id, peer_config in self._get_peers().items():
         try:
         try:
           if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
           if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
           peer = self.known_peers.get(peer_id)
           peer = self.known_peers.get(peer_id)
@@ -72,3 +69,44 @@ class ManualDiscovery(Discovery):
 
 
       if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
       if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
 
 
+  async def task_clean_up_peers_from_config(self):
+    if DEBUG_DISCOVERY >= 2: print("Starting task to clean up peers from config...")
+    while True:
+      peers_from_config = self._get_peers()
+      if peers_from_config:
+        peers_to_remove = [peer for peer in self.known_peers.keys() if peer not in peers_from_config]
+
+        for peer in peers_to_remove:
+          if DEBUG_DISCOVERY >= 2: print(f"{peer} is no longer found in the config but is currently a known peer. Removing from known peers...")
+          try: del self.known_peers[peer]
+          except KeyError: pass
+
+      await asyncio.sleep(1.0)
+
+  def _get_peers(self):
+    try:
+      current_mtime = os.path.getmtime(self.network_config_path)
+
+      if self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time:
+        return self._cached_peers
+
+      topology = NetworkTopology.from_path(self.network_config_path)
+
+      if self.node_id not in topology.peers:
+        raise ValueError(
+          f"Node ID {self.node_id} not found in network config file "
+          f"{self.network_config_path}. Please run with `node_id` set to "
+          f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
+        )
+
+      peers_in_network: Dict[str, PeerConfig] = topology.peers
+      peers_in_network.pop(self.node_id)
+
+      self._cached_peers = peers_in_network
+      self._last_modified_time = current_mtime
+
+      return peers_in_network
+
+    except Exception as e:
+      if DEBUG_DISCOVERY >= 2: print(f"Error when loading network config file from {self.network_config_path}. Please update the config file in order to successfully discover peers. Exception: {e}")
+      return self._cached_peers