Prechádzať zdrojové kódy

allow update to manual discovery file

re-load manual discovery file for each runthrough of the peer network, allowing incremental updates to the peer file even when exo is running
Ian Paul 11 mesiacov pred
rodič
commit
98118babae

+ 96 - 52
exo/networking/manual/manual_discovery.py

@@ -9,63 +9,107 @@ from exo.networking.peer_handle import PeerHandle
 
 
 class ManualDiscovery(Discovery):
-  def __init__(
-    self,
-    network_config_path: str,
-    node_id: str,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
-  ):
-    self.topology = NetworkTopology.from_path(network_config_path)
-    self.create_peer_handle = create_peer_handle
+    def __init__(
+        self,
+        network_config_path: str,
+        node_id: str,
+        create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    ):
+        self.network_config_path = network_config_path
+        self.node_id = node_id
+        self.create_peer_handle = create_peer_handle
 
-    if node_id not in self.topology.peers:
-      raise ValueError(
-        f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
-      )
+        if node_id not in self.topology.peers:
+            raise ValueError(
+                f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
+            )
 
-    self.listen_task = None
+        self.listen_task = None
+        self.known_peers: Dict[str, PeerHandle] = {}
 
-    self.known_peers: Dict[str, PeerHandle] = {}
-    self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
-    self.peers_in_network.pop(node_id)
+    async def start(self) -> None:
+        self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
 
-  async def start(self) -> None:
-    self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
+    async def stop(self) -> None:
+        if self.listen_task:
+            self.listen_task.cancel()
 
-  async def stop(self) -> None:
-    if self.listen_task:
-      self.listen_task.cancel()
+    async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+        if wait_for_peers > 0:
+            while len(self.known_peers) < wait_for_peers:
+                if DEBUG_DISCOVERY >= 2:
+                    print(
+                        f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers..."
+                    )
+                await asyncio.sleep(0.1)
+        if DEBUG_DISCOVERY >= 2:
+            print(
+                f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}"
+            )
+        return list(self.known_peers.values())
 
-  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-    if wait_for_peers > 0:
-      while len(self.known_peers) < wait_for_peers:
-        if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
-        await asyncio.sleep(0.1)
-    if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}")
-    return list(self.known_peers.values())
+    async def task_find_peers_from_config(self):
+        if DEBUG_DISCOVERY >= 2:
+            print("Starting task to find peers from config...")
+        while True:
+            peers = self._get_peers().items()
+            for peer_id, peer_config in peers:
+                try:
+                    if DEBUG_DISCOVERY >= 2:
+                        print(
+                            f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}"
+                        )
+                    peer = self.known_peers.get(peer_id)
+                    if not peer:
+                        if DEBUG_DISCOVERY >= 2:
+                            print(f"{peer_id=} not found in known peers. Adding.")
+                        peer = self.create_peer_handle(
+                            peer_id,
+                            f"{peer_config.address}:{peer_config.port}",
+                            peer_config.device_capabilities,
+                        )
+                        peer = self.create_peer_handle(
+                            peer_id,
+                            f"{peer_config.address}:{peer_config.port}",
+                            peer_config.device_capabilities,
+                        )
+                    is_healthy = await peer.health_check()
+                    if is_healthy:
+                        if DEBUG_DISCOVERY >= 2:
+                            print(
+                                f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy."
+                            )
+                        self.known_peers[peer_id] = peer
+                    else:
+                        if DEBUG_DISCOVERY >= 2:
+                            print(
+                                f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy."
+                            )
+                        try:
+                            del self.known_peers[peer_id]
+                        except KeyError:
+                            pass
+                except Exception as e:
+                    if DEBUG_DISCOVERY >= 2:
+                        print(
+                            f"Exception occured when attempting to add {peer_id=}: {e}"
+                        )
+            await asyncio.sleep(1.0)
 
-  async def task_find_peers_from_config(self):
-    if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
-    while True:
-      for peer_id, peer_config in self.peers_in_network.items():
-        try:
-          if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
-          peer = self.known_peers.get(peer_id)
-          if not peer:
-            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
-            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities)
-          is_healthy = await peer.health_check()
-          if is_healthy:
-            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
-            self.known_peers[peer_id] = peer
-          else:
-            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
-            try:
-              del self.known_peers[peer_id]
-            except KeyError:
-              pass
-        except Exception as e:
-          if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
-      await asyncio.sleep(1.0)
+            if DEBUG_DISCOVERY >= 2:
+                print(
+                    f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}"
+                )
 
-      if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
+    def _get_peers(self):
+        topology = NetworkTopology.from_path(self.network_config_path)
+
+        if self.node_id not in topology.peers:
+            raise ValueError(
+                f"Node ID {self.node_id} not found in network config file {self.network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in topology.peers]}"
+            )
+
+        peers_in_network: Dict[str, PeerConfig] = topology.peers
+        peers_in_network.pop(self.node_id)
+
+        return peers_in_network

+ 1 - 1
exo/networking/manual/test_data/test_config.json

@@ -29,4 +29,4 @@
       }
     }
   }
-}
+}

+ 51 - 4
exo/networking/manual/test_manual_discovery.py

@@ -1,3 +1,4 @@
+import json
 import asyncio
 import unittest
 from unittest import mock
@@ -44,9 +45,9 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
 
   async def test_discovery(self):
     peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
-    assert len(peers1) == 1
+    self.assertEqual(len(peers1), 1)
     peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
-    assert len(peers2) == 1
+    self.assertEqual(len(peers2), 1)
 
     # connect has to be explicitly called after discovery
     self.peer1.connect.assert_not_called()
@@ -76,9 +77,9 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
 
   async def test_grpc_discovery(self):
     peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
-    assert len(peers1) == 1
+    self.assertEqual(len(peers1), 1)
     peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
-    assert len(peers2) == 1
+    self.assertEqual(len(peers2), 1)
 
     # Connect
     await peers1[0].connect()
@@ -98,6 +99,52 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.assertFalse(await peers1[0].is_connected())
     self.assertFalse(await peers2[0].is_connected())
 
+  async def test_dynamic_config_update(self):
+    initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
+    self.assertEqual(len(initial_peers), 1)
+
+    # Save original config for cleanup
+    with open(root_path, "r") as f:
+      original_config = json.load(f)
+
+    try:
+      updated_config = {
+        "peers": {
+          **original_config["peers"],
+          "node3": {
+            "address": "localhost",
+            "port": 50053,
+            "device_capabilities": {"model": "Unknown Model", "chip": "Unknown Chip", "memory": 0, "flops": {"fp32": 0, "fp16": 0, "int8": 0}},
+          },
+        }
+      }
+
+      with open(root_path, "w") as f:
+        json.dump(updated_config, f, indent=2)
+
+      node3 = mock.AsyncMock(spec=Node)
+      server3 = GRPCServer(node3, "localhost", 50053)
+      await server3.start()
+
+      try:
+        # Wait for the config to be reloaded
+        await asyncio.sleep(1.5)
+
+        updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
+        self.assertEqual(len(updated_peers), 2)
+
+        for peer in updated_peers:
+          await peer.connect()
+          self.assertTrue(await peer.is_connected())
+
+      finally:
+        await server3.stop()
+
+    finally:
+      # Restore the original config file
+      with open(root_path, "w") as f:
+        json.dump(original_config, f, indent=2)
+
 
 if __name__ == "__main__":
   asyncio.run(unittest.main())