Просмотр исходного кода

formatting and fixing tests after rebasing

Ian Paul 4 месяцев назад
Родитель
Сommit
b003292b89
2 измененных файлов с 158 добавлено и 175 удалено
  1. 2 2
      exo/networking/manual/manual_discovery.py
  2. 156 173
      exo/networking/manual/test_manual_discovery.py

+ 2 - 2
exo/networking/manual/manual_discovery.py

@@ -15,7 +15,7 @@ class ManualDiscovery(Discovery):
     self,
     self,
     network_config_path: str,
     network_config_path: str,
     node_id: str,
     node_id: str,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
   ):
   ):
     self.network_config_path = network_config_path
     self.network_config_path = network_config_path
     self.node_id = node_id
     self.node_id = node_id
@@ -54,7 +54,7 @@ class ManualDiscovery(Discovery):
           peer = self.known_peers.get(peer_id)
           peer = self.known_peers.get(peer_id)
           if not peer:
           if not peer:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
-            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
+            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities)
           is_healthy = await peer.health_check()
           is_healthy = await peer.health_check()
           if is_healthy:
           if is_healthy:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")

+ 156 - 173
exo/networking/manual/test_manual_discovery.py

@@ -12,187 +12,170 @@ root_path = "./exo/networking/manual/test_data/test_config.json"
 
 
 
 
 class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
 class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
-    async def asyncSetUp(self):
-        self.peer1 = mock.AsyncMock(spec=Node)
-        self.peer1.connect = mock.AsyncMock()
-        self.server1 = GRPCServer(self.peer1, "localhost", 8000)
-        await self.server1.start()
-        self.discovery1 = ManualDiscovery(
-            root_path,
-            "node1",
-            create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(
-                peer_id, address, device_capabilities
-            ),
-        )
-        await self.discovery1.start()
-
-    async def asyncTearDown(self):
-        await self.discovery1.stop()
-        await self.server1.stop()
-
-    async def test_discovery(self):
-        peers1 = await self.discovery1.discover_peers(wait_for_peers=0)
-        self.assertEqual(len(peers1), 0)
-
-        self.peer1.connect.assert_not_called()
+  async def asyncSetUp(self):
+    self.peer1 = mock.AsyncMock()
+    self.peer1.connect = mock.AsyncMock()
+    self.discovery1 = ManualDiscovery(
+      root_path,
+      "node1",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
+    )
+    await self.discovery1.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+
+  async def test_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=0)
+    assert len(peers1) == 0
+
+    self.peer1.connect.assert_not_called()
 
 
 
 
 class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
 class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
-    async def asyncSetUp(self):
-        self.peer1 = mock.AsyncMock()
-        self.peer2 = mock.AsyncMock()
-
-        self.peer1.id = mock.MagicMock(return_value="node2")
-        self.peer2.id = mock.MagicMock(return_value="node1")
-
-        self.peer1.connect = mock.AsyncMock()
-        self.peer2.connect = mock.AsyncMock()
-        self.discovery1 = ManualDiscovery(
-            root_path,
-            "node1",
-            create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
-        )
-        self.discovery2 = ManualDiscovery(
-            root_path,
-            "node2",
-            create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
-        )
-        await self.discovery1.start()
-        await self.discovery2.start()
-
-    async def asyncTearDown(self):
-        await self.discovery1.stop()
-        await self.discovery2.stop()
-
-    async def test_discovery(self):
-        peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
-        self.assertEqual(len(peers1), 1)
-        peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
-        self.assertEqual(len(peers2), 1)
-
-        # connect has to be explicitly called after discovery
-        self.peer1.connect.assert_not_called()
-        self.peer2.connect.assert_not_called()
+  async def asyncSetUp(self):
+    self.peer1 = mock.AsyncMock()
+    self.peer2 = mock.AsyncMock()
+    self.peer1.connect = mock.AsyncMock()
+    self.peer2.connect = mock.AsyncMock()
+    self.discovery1 = ManualDiscovery(
+      root_path,
+      "node1",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
+    )
+    self.discovery2 = ManualDiscovery(
+      root_path,
+      "node2",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
+    )
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+
+  async def test_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    assert len(peers1) == 1
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    assert len(peers2) == 1
+
+    # connect has to be explicitly called after discovery
+    self.peer1.connect.assert_not_called()
+    self.peer2.connect.assert_not_called()
 
 
 
 
 class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
 class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
-    async def asyncSetUp(self):
-        config = NetworkTopology.from_path(root_path)
-
-        self.node1 = mock.AsyncMock(spec=Node)
-        self.node2 = mock.AsyncMock(spec=Node)
-        self.server1 = GRPCServer(
-            self.node1, config.peers["node1"].address, config.peers["node1"].port
-        )
-        self.server2 = GRPCServer(
-            self.node2, config.peers["node2"].address, config.peers["node2"].port
-        )
-        await self.server1.start()
-        await self.server2.start()
-        self.discovery1 = ManualDiscovery(
-            root_path,
-            "node1",
-            create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(
-                peer_id, address, description, device_capabilities
-            ),
-        )
-        self.discovery2 = ManualDiscovery(
-            root_path,
-            "node2",
-            create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(
-                peer_id, address, description, device_capabilities
-            ),
-        )
-        await self.discovery1.start()
-        await self.discovery2.start()
-
-    async def asyncTearDown(self):
-        await self.discovery1.stop()
-        await self.discovery2.stop()
-        await self.server1.stop()
-        await self.server2.stop()
-
-    async def test_grpc_discovery(self):
-        peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
-        self.assertEqual(len(peers1), 1)
-        peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
-        self.assertEqual(len(peers2), 1)
-
-        # Connect
-        await peers1[0].connect()
-        await peers2[0].connect()
-        self.assertTrue(await peers1[0].is_connected())
-        self.assertTrue(await peers2[0].is_connected())
-
-        # Kill server1
-        await self.server1.stop()
-
-        self.assertTrue(await peers1[0].is_connected())
-        self.assertFalse(await peers2[0].is_connected())
-
-        # Kill server2
-        await self.server2.stop()
-
-        self.assertFalse(await peers1[0].is_connected())
-        self.assertFalse(await peers2[0].is_connected())
-
-    async def test_dynamic_config_update(self):
-        initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
-        self.assertEqual(len(initial_peers), 1)
-
-        # Save original config for cleanup
-        with open(root_path, "r") as f:
-            original_config = json.load(f)
-
-        try:
-            updated_config = {
-                "peers": {
-                    **original_config["peers"],
-                    "node3": {
-                        "address": "localhost",
-                        "port": 50053,
-                        "device_capabilities": {
-                            "model": "Unknown Model",
-                            "chip": "Unknown Chip",
-                            "memory": 0,
-                            "flops": {"fp32": 0, "fp16": 0, "int8": 0},
-                        },
-                    },
-                }
-            }
-
-            with open(root_path, "w") as f:
-                json.dump(updated_config, f, indent=2)
-
-            node3 = mock.AsyncMock(spec=Node)
-            server3 = GRPCServer(node3, "localhost", 50053)
-            await server3.start()
-
-            try:
-                # Wait for the config to be reloaded
-                await asyncio.sleep(1.5)
-
-                updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
-                self.assertEqual(len(updated_peers), 2)
-
-                for peer in updated_peers:
-                    await peer.connect()
-                    self.assertTrue(await peer.is_connected())
-
-            finally:
-                await server3.stop()
-
-        finally:
-            # Restore the original config file
-            with open(root_path, "w") as f:
-                json.dump(original_config, f, indent=2)
-
-        # Wait for the config to be reloaded again
+  async def asyncSetUp(self):
+    config = NetworkTopology.from_path(root_path)
+
+    self.node1 = mock.AsyncMock(spec=Node)
+    self.node2 = mock.AsyncMock(spec=Node)
+    self.server1 = GRPCServer(self.node1, config.peers["node1"].address, config.peers["node1"].port)
+    self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
+    await self.server1.start()
+    await self.server2.start()
+    self.discovery1 = ManualDiscovery(
+      root_path,
+      "node1",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
+    )
+    self.discovery2 = ManualDiscovery(
+      root_path,
+      "node2",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
+    )
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+    await self.server1.stop()
+    await self.server2.stop()
+
+  async def test_grpc_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    assert len(peers1) == 1
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    assert len(peers2) == 1
+
+    # Connect
+    await peers1[0].connect()
+    await peers2[0].connect()
+    self.assertTrue(await peers1[0].is_connected())
+    self.assertTrue(await peers2[0].is_connected())
+
+    # Kill server1
+    await self.server1.stop()
+
+    self.assertTrue(await peers1[0].is_connected())
+    self.assertFalse(await peers2[0].is_connected())
+
+    # Kill server2
+    await self.server2.stop()
+
+    self.assertFalse(await peers1[0].is_connected())
+    self.assertFalse(await peers2[0].is_connected())
+
+  async def test_dynamic_config_update(self):
+    initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
+    self.assertEqual(len(initial_peers), 1)
+
+    # Save original config for cleanup
+    with open(root_path, "r") as f:
+      original_config = json.load(f)
+
+    try:
+      updated_config = {
+        "peers": {
+          **original_config["peers"],
+          "node3": {
+            "address": "localhost",
+            "port": 50053,
+            "device_capabilities": {
+              "model": "Unknown Model",
+              "chip": "Unknown Chip",
+              "memory": 0,
+              "flops": {"fp32": 0, "fp16": 0, "int8": 0},
+            },
+          },
+        }
+      }
+
+      with open(root_path, "w") as f:
+        json.dump(updated_config, f, indent=2)
+
+      node3 = mock.AsyncMock(spec=Node)
+      server3 = GRPCServer(node3, "localhost", 50053)
+      await server3.start()
+
+      try:
+        # Wait for the config to be reloaded
         await asyncio.sleep(1.5)
         await asyncio.sleep(1.5)
 
 
-        updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
-        self.assertEqual(len(updated_peers), 1)
+        updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
+        self.assertEqual(len(updated_peers), 2)
+
+        for peer in updated_peers:
+          await peer.connect()
+          self.assertTrue(await peer.is_connected())
+
+      finally:
+        await server3.stop()
+
+    finally:
+      # Restore the original config file
+      with open(root_path, "w") as f:
+        json.dump(original_config, f, indent=2)
+
+    # Wait for the config to be reloaded again
+    await asyncio.sleep(1.5)
+
+    updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
+    self.assertEqual(len(updated_peers), 1)
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    asyncio.run(unittest.main())
+  asyncio.run(unittest.main())