test_node.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import unittest
  2. from unittest.mock import Mock, AsyncMock
  3. import numpy as np
  4. from .standard_node import StandardNode
  5. from networking.peer_handle import PeerHandle
  6. class TestNode(unittest.IsolatedAsyncioTestCase):
  7. def setUp(self):
  8. self.mock_inference_engine = AsyncMock()
  9. self.mock_server = AsyncMock()
  10. self.mock_server.start = AsyncMock()
  11. self.mock_server.stop = AsyncMock()
  12. self.mock_discovery = AsyncMock()
  13. self.mock_discovery.start = AsyncMock()
  14. self.mock_discovery.stop = AsyncMock()
  15. mock_peer1 = Mock(spec=PeerHandle)
  16. mock_peer1.id.return_value = "peer1"
  17. mock_peer2 = Mock(spec=PeerHandle)
  18. mock_peer2.id.return_value = "peer2"
  19. self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
  20. self.node = StandardNode("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery)
  21. async def asyncSetUp(self):
  22. await self.node.start()
  23. async def asyncTearDown(self):
  24. await self.node.stop()
  25. async def test_node_initialization(self):
  26. self.assertEqual(self.node.node_id, "test_node")
  27. self.assertEqual(self.node.host, "localhost")
  28. self.assertEqual(self.node.port, 50051)
  29. async def test_node_start(self):
  30. self.mock_server.start.assert_called_once_with("localhost", 50051)
  31. async def test_node_stop(self):
  32. await self.node.stop()
  33. self.mock_server.stop.assert_called_once()
  34. async def test_discover_and_connect_to_peers(self):
  35. await self.node.discover_and_connect_to_peers()
  36. self.assertEqual(len(self.node.peers), 2)
  37. self.assertIn("peer1", map(lambda p: p.id(), self.node.peers))
  38. self.assertIn("peer2", map(lambda p: p.id(), self.node.peers))
  39. async def test_process_tensor_calls_inference_engine(self):
  40. mock_peer = Mock()
  41. self.node.peers = [mock_peer]
  42. input_tensor = np.array([69, 1, 2])
  43. await self.node.process_tensor(input_tensor, None)
  44. self.node.inference_engine.process_shard.assert_called_once_with(input_tensor)