Przeglądaj źródła

Merge pull request #383 from ianpaul10/feat/manual-disc-follow-up

Support changing manual configuration while running
Alex Cheema 7 miesięcy temu
rodzic
commit
a174c78004

+ 52 - 22
exo/networking/manual/manual_discovery.py

@@ -1,7 +1,9 @@
+import os
 import asyncio
-from exo.networking.discovery import Discovery
-from typing import Dict, List, Callable
+from typing import Dict, List, Callable, Optional
+from concurrent.futures import ThreadPoolExecutor
 
+from exo.networking.discovery import Discovery
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
 from exo.helpers import DEBUG_DISCOVERY
@@ -13,28 +15,25 @@ class ManualDiscovery(Discovery):
     self,
     network_config_path: str,
     node_id: str,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
   ):
-    self.topology = NetworkTopology.from_path(network_config_path)
+    self.network_config_path = network_config_path
+    self.node_id = node_id
     self.create_peer_handle = create_peer_handle
 
-    if node_id not in self.topology.peers:
-      raise ValueError(
-        f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
-      )
-
     self.listen_task = None
-
     self.known_peers: Dict[str, PeerHandle] = {}
-    self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
-    self.peers_in_network.pop(node_id)
+
+    self._cached_peers: Dict[str, PeerConfig] = {}
+    self._last_modified_time: Optional[float] = None
+    self._file_executor = ThreadPoolExecutor(max_workers=1)
 
   async def start(self) -> None:
     self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
 
   async def stop(self) -> None:
-    if self.listen_task:
-      self.listen_task.cancel()
+    if self.listen_task: self.listen_task.cancel()
+    self._file_executor.shutdown(wait=True)
 
   async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
     if wait_for_peers > 0:
@@ -47,7 +46,9 @@ class ManualDiscovery(Discovery):
   async def task_find_peers_from_config(self):
     if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
     while True:
-      for peer_id, peer_config in self.peers_in_network.items():
+      peers_from_config = await self._get_peers()
+      new_known_peers = {}
+      for peer_id, peer_config in peers_from_config.items():
         try:
           if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
           peer = self.known_peers.get(peer_id)
@@ -57,15 +58,44 @@ class ManualDiscovery(Discovery):
           is_healthy = await peer.health_check()
           if is_healthy:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
-            self.known_peers[peer_id] = peer
-          else:
-            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
-            try:
-              del self.known_peers[peer_id]
-            except KeyError:
-              pass
+            new_known_peers[peer_id] = peer
+          elif DEBUG_DISCOVERY >= 2:
+            print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
         except Exception as e:
           if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
+      self.known_peers = new_known_peers
       await asyncio.sleep(1.0)
 
       if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
+
+  async def _get_peers(self):
+    try:
+      loop = asyncio.get_running_loop()
+      current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)
+
+      if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
+        return self._cached_peers
+
+      topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)
+
+      if self.node_id not in topology.peers:
+        raise ValueError(
+          f"Node ID {self.node_id} not found in network config file "
+          f"{self.network_config_path}. Please run with `node_id` set to "
+          f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
+        )
+
+      peers_in_network = topology.peers
+      peers_in_network.pop(self.node_id)
+
+      self._cached_peers = peers_in_network
+      self._last_modified_time = current_mtime
+
+      return peers_in_network
+
+    except Exception as e:
+      if DEBUG_DISCOVERY >= 2:
+        print(f"Error when loading network config file from {self.network_config_path}. "
+              f"Please update the config file in order to successfully discover peers. "
+              f"Exception: {e}")
+      return self._cached_peers

+ 1 - 1
exo/networking/manual/test_data/test_config.json

@@ -29,4 +29,4 @@
       }
     }
   }
-}
+}

+ 84 - 6
exo/networking/manual/test_manual_discovery.py

@@ -1,3 +1,4 @@
+import json
 import asyncio
 import unittest
 from unittest import mock
@@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
     self.peer1 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
-    _ = self.discovery1.start()
+    self.discovery1 = ManualDiscovery(
+      root_path,
+      "node1",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
+    )
+    await self.discovery1.start()
 
   async def asyncTearDown(self):
     await self.discovery1.stop()
@@ -33,8 +38,16 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
     self.peer2 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
     self.peer2.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
+    self.discovery1 = ManualDiscovery(
+      root_path,
+      "node1",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
+    )
+    self.discovery2 = ManualDiscovery(
+      root_path,
+      "node2",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
+    )
     await self.discovery1.start()
     await self.discovery2.start()
 
@@ -63,8 +76,16 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
     await self.server1.start()
     await self.server2.start()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
+    self.discovery1 = ManualDiscovery(
+      root_path,
+      "node1",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
+    )
+    self.discovery2 = ManualDiscovery(
+      root_path,
+      "node2",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
+    )
     await self.discovery1.start()
     await self.discovery2.start()
 
@@ -98,6 +119,63 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.assertFalse(await peers1[0].is_connected())
     self.assertFalse(await peers2[0].is_connected())
 
+  async def test_dynamic_config_update(self):
+    initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
+    self.assertEqual(len(initial_peers), 1)
+
+    # Save original config for cleanup
+    with open(root_path, "r") as f:
+      original_config = json.load(f)
+
+    try:
+      updated_config = {
+        "peers": {
+          **original_config["peers"],
+          "node3": {
+            "address": "localhost",
+            "port": 50053,
+            "device_capabilities": {
+              "model": "Unknown Model",
+              "chip": "Unknown Chip",
+              "memory": 0,
+              "flops": {"fp32": 0, "fp16": 0, "int8": 0},
+            },
+          },
+        }
+      }
+
+      with open(root_path, "w") as f:
+        json.dump(updated_config, f, indent=2)
+
+      node3 = mock.AsyncMock(spec=Node)
+      server3 = GRPCServer(node3, "localhost", 50053)
+      await server3.start()
+
+      try:
+        # Wait for the config to be reloaded
+        await asyncio.sleep(1.5)
+
+        updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
+        self.assertEqual(len(updated_peers), 2)
+
+        for peer in updated_peers:
+          await peer.connect()
+          self.assertTrue(await peer.is_connected())
+
+      finally:
+        await server3.stop()
+
+    finally:
+      # Restore the original config file
+      with open(root_path, "w") as f:
+        json.dump(original_config, f, indent=2)
+
+    # Wait for the config to be reloaded again
+    await asyncio.sleep(1.5)
+
+    updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
+    self.assertEqual(len(updated_peers), 1)
+
 
 if __name__ == "__main__":
   asyncio.run(unittest.main())