Browse Source

allow update to manual discovery file

re-load manual discovery file for each runthrough of the peer network, allowing incremental updates to the peer file even when exo is running
Ian Paul 8 months ago
parent
commit
98118babae

+ 96 - 52
exo/networking/manual/manual_discovery.py

@@ -9,63 +9,107 @@ from exo.networking.peer_handle import PeerHandle
 
 
 class ManualDiscovery(Discovery):
-  def __init__(
-    self,
-    network_config_path: str,
-    node_id: str,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
-  ):
-    self.topology = NetworkTopology.from_path(network_config_path)
-    self.create_peer_handle = create_peer_handle
+    def __init__(
+        self,
+        network_config_path: str,
+        node_id: str,
+        create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    ):
+        self.network_config_path = network_config_path
+        self.node_id = node_id
+        self.create_peer_handle = create_peer_handle
 
-    if node_id not in self.topology.peers:
-      raise ValueError(
-        f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
-      )
+        if node_id not in self.topology.peers:
+            raise ValueError(
+                f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
+            )
 
-    self.listen_task = None
+        self.listen_task = None
+        self.known_peers: Dict[str, PeerHandle] = {}
 
-    self.known_peers: Dict[str, PeerHandle] = {}
-    self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
-    self.peers_in_network.pop(node_id)
+    async def start(self) -> None:
+        self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
 
-  async def start(self) -> None:
-    self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
+    async def stop(self) -> None:
+        if self.listen_task:
+            self.listen_task.cancel()
 
-  async def stop(self) -> None:
-    if self.listen_task:
-      self.listen_task.cancel()
+    async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+        if wait_for_peers > 0:
+            while len(self.known_peers) < wait_for_peers:
+                if DEBUG_DISCOVERY >= 2:
+                    print(
+                        f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers..."
+                    )
+                await asyncio.sleep(0.1)
+        if DEBUG_DISCOVERY >= 2:
+            print(
+                f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}"
+            )
+        return list(self.known_peers.values())
 
-  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-    if wait_for_peers > 0:
-      while len(self.known_peers) < wait_for_peers:
-        if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
-        await asyncio.sleep(0.1)
-    if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}")
-    return list(self.known_peers.values())
+    async def task_find_peers_from_config(self):
+        if DEBUG_DISCOVERY >= 2:
+            print("Starting task to find peers from config...")
+        while True:
+            peers = self._get_peers().items()
+            for peer_id, peer_config in peers:
+                try:
+                    if DEBUG_DISCOVERY >= 2:
+                        print(
+                            f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}"
+                        )
+                    peer = self.known_peers.get(peer_id)
+                    if not peer:
+                        if DEBUG_DISCOVERY >= 2:
+                            print(f"{peer_id=} not found in known peers. Adding.")
+                        peer = self.create_peer_handle(
+                            peer_id,
+                            f"{peer_config.address}:{peer_config.port}",
+                            peer_config.device_capabilities,
+                        )
+                        peer = self.create_peer_handle(
+                            peer_id,
+                            f"{peer_config.address}:{peer_config.port}",
+                            peer_config.device_capabilities,
+                        )
+                    is_healthy = await peer.health_check()
+                    if is_healthy:
+                        if DEBUG_DISCOVERY >= 2:
+                            print(
+                                f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy."
+                            )
+                        self.known_peers[peer_id] = peer
+                    else:
+                        if DEBUG_DISCOVERY >= 2:
+                            print(
+                                f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy."
+                            )
+                        try:
+                            del self.known_peers[peer_id]
+                        except KeyError:
+                            pass
+                except Exception as e:
+                    if DEBUG_DISCOVERY >= 2:
+                        print(
+                            f"Exception occured when attempting to add {peer_id=}: {e}"
+                        )
+            await asyncio.sleep(1.0)
 
-  async def task_find_peers_from_config(self):
-    if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
-    while True:
-      for peer_id, peer_config in self.peers_in_network.items():
-        try:
-          if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
-          peer = self.known_peers.get(peer_id)
-          if not peer:
-            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
-            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities)
-          is_healthy = await peer.health_check()
-          if is_healthy:
-            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
-            self.known_peers[peer_id] = peer
-          else:
-            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
-            try:
-              del self.known_peers[peer_id]
-            except KeyError:
-              pass
-        except Exception as e:
-          if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
-      await asyncio.sleep(1.0)
+            if DEBUG_DISCOVERY >= 2:
+                print(
+                    f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}"
+                )
 
-      if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
+    def _get_peers(self):
+        topology = NetworkTopology.from_path(self.network_config_path)
+
+        if self.node_id not in topology.peers:
+            raise ValueError(
+                f"Node ID {self.node_id} not found in network config file {self.network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in topology.peers]}"
+            )
+
+        peers_in_network: Dict[str, PeerConfig] = topology.peers
+        peers_in_network.pop(self.node_id)
+
+        return peers_in_network

+ 1 - 1
exo/networking/manual/test_data/test_config.json

@@ -29,4 +29,4 @@
       }
     }
   }
-}
+}

+ 51 - 4
exo/networking/manual/test_manual_discovery.py

@@ -1,3 +1,4 @@
+import json
 import asyncio
 import unittest
 from unittest import mock
@@ -44,9 +45,9 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
 
   async def test_discovery(self):
     peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
-    assert len(peers1) == 1
+    self.assertEqual(len(peers1), 1)
     peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
-    assert len(peers2) == 1
+    self.assertEqual(len(peers2), 1)
 
     # connect has to be explicitly called after discovery
     self.peer1.connect.assert_not_called()
@@ -76,9 +77,9 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
 
   async def test_grpc_discovery(self):
     peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
-    assert len(peers1) == 1
+    self.assertEqual(len(peers1), 1)
     peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
-    assert len(peers2) == 1
+    self.assertEqual(len(peers2), 1)
 
     # Connect
     await peers1[0].connect()
@@ -98,6 +99,52 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.assertFalse(await peers1[0].is_connected())
     self.assertFalse(await peers2[0].is_connected())
 
+  async def test_dynamic_config_update(self):
+    initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
+    self.assertEqual(len(initial_peers), 1)
+
+    # Save original config for cleanup
+    with open(root_path, "r") as f:
+      original_config = json.load(f)
+
+    try:
+      updated_config = {
+        "peers": {
+          **original_config["peers"],
+          "node3": {
+            "address": "localhost",
+            "port": 50053,
+            "device_capabilities": {"model": "Unknown Model", "chip": "Unknown Chip", "memory": 0, "flops": {"fp32": 0, "fp16": 0, "int8": 0}},
+          },
+        }
+      }
+
+      with open(root_path, "w") as f:
+        json.dump(updated_config, f, indent=2)
+
+      node3 = mock.AsyncMock(spec=Node)
+      server3 = GRPCServer(node3, "localhost", 50053)
+      await server3.start()
+
+      try:
+        # Wait for the config to be reloaded
+        await asyncio.sleep(1.5)
+
+        updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
+        self.assertEqual(len(updated_peers), 2)
+
+        for peer in updated_peers:
+          await peer.connect()
+          self.assertTrue(await peer.is_connected())
+
+      finally:
+        await server3.stop()
+
+    finally:
+      # Restore the original config file
+      with open(root_path, "w") as f:
+        json.dump(original_config, f, indent=2)
+
 
 if __name__ == "__main__":
   asyncio.run(unittest.main())