Kaynağa Gözat

tests for manual networking

Ian Paul 6 ay önce
ebeveyn
işleme
1970b9c89f

+ 17 - 0
exo/networking/manual/test_data/invalid_config.json

@@ -0,0 +1,17 @@
+{
+  "peers": {
+    "node1": {
+      "address": "localhost",
+      "device_capabilities": {
+        "model": "Unknown Model",
+        "chip": "Unknown Chip",
+        "memory": 0,
+        "flops": {
+          "fp32": 0,
+          "fp16": 0,
+          "int8": 0
+        }
+      }
+    }
+  }
+}

+ 0 - 0
exo/networking/manual/test_data/invalid_json.json


+ 32 - 0
exo/networking/manual/test_data/test_config.json

@@ -0,0 +1,32 @@
+{
+  "peers": {
+    "node1": {
+      "address": "localhost",
+      "port": 50051,
+      "device_capabilities": {
+        "model": "Unknown Model",
+        "chip": "Unknown Chip",
+        "memory": 0,
+        "flops": {
+          "fp32": 0,
+          "fp16": 0,
+          "int8": 0
+        }
+      }
+    },
+    "node2": {
+      "address": "localhost",
+      "port": 50052,
+      "device_capabilities": {
+        "model": "Unknown Model",
+        "chip": "Unknown Chip",
+        "memory": 0,
+        "flops": {
+          "fp32": 0,
+          "fp16": 0,
+          "int8": 0
+        }
+      }
+    }
+  }
+}

+ 18 - 0
exo/networking/manual/test_data/test_config_single_node.json

@@ -0,0 +1,18 @@
+{
+  "peers": {
+    "node1": {
+      "address": "localhost",
+      "port": 50051,
+      "device_capabilities": {
+        "model": "Unknown Model",
+        "chip": "Unknown Chip",
+        "memory": 0,
+        "flops": {
+          "fp32": 0,
+          "fp16": 0,
+          "int8": 0
+        }
+      }
+    }
+  }
+}

+ 103 - 0
exo/networking/manual/test_manual_discovery.py

@@ -0,0 +1,103 @@
+import asyncio
+import unittest
+from unittest import mock
+from exo.networking.manual.manual_discovery import ManualDiscovery
+from exo.networking.manual.network_topology_config import NetworkTopology
+from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
+from exo.networking.grpc.grpc_server import GRPCServer
+from exo.orchestration.node import Node
+
+root_path = "./exo/networking/manual/test_data/test_config.json"
+
+
+class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.peer1 = mock.AsyncMock()
+    self.peer1.connect = mock.AsyncMock()
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    _ = self.discovery1.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+
+  async def test_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=0)
+    assert len(peers1) == 0
+
+    self.peer1.connect.assert_not_called()
+
+
+class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.peer1 = mock.AsyncMock()
+    self.peer2 = mock.AsyncMock()
+    self.peer1.connect = mock.AsyncMock()
+    self.peer2.connect = mock.AsyncMock()
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+
+  async def test_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    assert len(peers1) == 1
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    assert len(peers2) == 1
+
+    # connect has to be explicitly called after discovery
+    self.peer1.connect.assert_not_called()
+    self.peer2.connect.assert_not_called()
+
+
+class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    config = NetworkTopology.from_path(root_path)
+
+    self.node1 = mock.AsyncMock(spec=Node)
+    self.node2 = mock.AsyncMock(spec=Node)
+    self.server1 = GRPCServer(self.node1, config.peers["node1"].address, config.peers["node1"].port)
+    self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
+    await self.server1.start()
+    await self.server2.start()
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+    await self.server1.stop()
+    await self.server2.stop()
+
+  async def test_grpc_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    assert len(peers1) == 1
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    assert len(peers2) == 1
+
+    # Connect
+    await peers1[0].connect()
+    await peers2[0].connect()
+    self.assertTrue(await peers1[0].is_connected())
+    self.assertTrue(await peers2[0].is_connected())
+
+    # Kill server1
+    await self.server1.stop()
+
+    self.assertTrue(await peers1[0].is_connected())
+    self.assertFalse(await peers2[0].is_connected())
+
+    # Kill server2
+    await self.server2.stop()
+
+    self.assertFalse(await peers1[0].is_connected())
+    self.assertFalse(await peers2[0].is_connected())
+
+
+if __name__ == "__main__":
+  asyncio.run(unittest.main())

+ 47 - 0
exo/networking/manual/test_network_topology_config.py

@@ -0,0 +1,47 @@
+import unittest
+import json
+from exo.networking.manual.network_topology_config import NetworkTopology
+
+root_path = "./exo/networking/manual/test_data/"
+
+
+class TestNetworkTopologyConfig(unittest.TestCase):
+  def test_from_path_invalid_path(self):
+    with self.assertRaises(FileNotFoundError) as e:
+      NetworkTopology.from_path("invalid_path")
+    self.assertEqual(e.exception.args[0], "Config file not found at invalid_path")
+
+  def test_from_path_invalid_json(self):
+    with self.assertRaises(json.JSONDecodeError) as e:
+      NetworkTopology.from_path(root_path + "invalid_json.json")
+    self.assertEqual(e.exception.args[0], "Error decoding JSON data from ./exo/networking/manual/test_data/invalid_json.json: Expecting value: line 1 column 1 (char 0): line 1 column 1 (char 0)")
+
+  def test_from_path_invalid_config(self):
+    with self.assertRaises(KeyError) as e:
+      NetworkTopology.from_path(root_path + "invalid_config.json")
+    self.assertEqual(e.exception.args[0], "Missing required key in config file: 'port'")
+
+  def test_from_path_valid(self):
+    config = NetworkTopology.from_path(root_path + "test_config.json")
+
+    self.assertEqual(config.peers["node1"].port, 50051)
+    self.assertEqual(config.peers["node1"].device_capabilities.model, "Unknown Model")
+    self.assertEqual(config.peers["node1"].address, "localhost")
+    self.assertEqual(config.peers["node1"].device_capabilities.chip, "Unknown Chip")
+    self.assertEqual(config.peers["node1"].device_capabilities.memory, 0)
+    self.assertEqual(config.peers["node1"].device_capabilities.flops.fp32, 0)
+    self.assertEqual(config.peers["node1"].device_capabilities.flops.fp16, 0)
+    self.assertEqual(config.peers["node1"].device_capabilities.flops.int8, 0)
+
+    self.assertEqual(config.peers["node2"].port, 50052)
+    self.assertEqual(config.peers["node2"].device_capabilities.model, "Unknown Model")
+    self.assertEqual(config.peers["node2"].address, "localhost")
+    self.assertEqual(config.peers["node2"].device_capabilities.chip, "Unknown Chip")
+    self.assertEqual(config.peers["node2"].device_capabilities.memory, 0)
+    self.assertEqual(config.peers["node2"].device_capabilities.flops.fp32, 0)
+    self.assertEqual(config.peers["node2"].device_capabilities.flops.fp16, 0)
+    self.assertEqual(config.peers["node2"].device_capabilities.flops.int8, 0)
+
+
+if __name__ == "__main__":
+  unittest.main()