test_udp_discovery.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import asyncio
  2. import unittest
  3. from unittest import mock
  4. from exo.networking.udp.udp_discovery import UDPDiscovery
  5. from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
  6. from exo.networking.grpc.grpc_server import GRPCServer
  7. from exo.orchestration.node import Node
  class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
      """Two UDPDiscovery instances, backed by mock peer handles, find each other.

      discovery1 uses ports (5678, 5679) and discovery2 the swapped pair
      (5679, 5678), so each is expected to discover exactly one peer.
      NOTE(review): argument order presumably is
      (node_id, node_port, listen_port, broadcast_port) -- confirm against
      UDPDiscovery.
      """

      async def asyncSetUp(self):
          self.peer1 = mock.AsyncMock()
          self.peer2 = mock.AsyncMock()
          # Explicit AsyncMocks so the test can check that discovery never calls connect().
          self.peer1.connect = mock.AsyncMock()
          self.peer2.connect = mock.AsyncMock()
          self.discovery1 = UDPDiscovery(
              "discovery1",
              50051,
              5678,
              5679,
              create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
          )
          self.discovery2 = UDPDiscovery(
              "discovery2",
              50052,
              5679,
              5678,
              create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
          )
          # Register each stop() as soon as its start() succeeds. asyncTearDown is
          # skipped when asyncSetUp raises, which would leave discovery1 running
          # (presumably still holding its port) and break the later test class
          # that reuses ports 5678/5679. Cleanups run even when setup fails.
          await self.discovery1.start()
          self.addAsyncCleanup(self.discovery1.stop)
          await self.discovery2.start()
          self.addAsyncCleanup(self.discovery2.stop)

      async def test_discovery(self):
          """Each side discovers exactly one peer and does not auto-connect."""
          peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
          self.assertEqual(len(peers1), 1)
          peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
          self.assertEqual(len(peers2), 1)
          # connect has to be explicitly called after discovery
          self.peer1.connect.assert_not_called()
          self.peer2.connect.assert_not_called()
  class TestUDPDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
      """UDP discovery wired to real GRPCServer / GRPCPeerHandle instances.

      Checks that discovered handles start disconnected, connect on demand, and
      report disconnection once the remote server is stopped. Each discovery
      instance advertises the port of its matching server (50053 / 50054).
      """

      async def asyncSetUp(self):
          self.node1 = mock.AsyncMock(spec=Node)
          self.node2 = mock.AsyncMock(spec=Node)
          # NOTE(review): servers bind to "localhost"; this assumes the address
          # UDPDiscovery reports for a peer is reachable there -- confirm.
          self.server1 = GRPCServer(self.node1, "localhost", 50053)
          self.server2 = GRPCServer(self.node2, "localhost", 50054)
          # Register each stop() right after its start() succeeds, so a failure
          # part-way through setup still releases whatever was already started
          # (asyncTearDown is not run when asyncSetUp raises). Cleanups run LIFO:
          # discovery instances are stopped before the servers.
          await self.server1.start()
          # server1 is also stopped inside test_grpc_discovery, so this cleanup
          # stops it a second time, exactly as the original asyncTearDown did.
          # NOTE(review): assumes GRPCServer.stop() tolerates a repeated call.
          self.addAsyncCleanup(self.server1.stop)
          await self.server2.start()
          self.addAsyncCleanup(self.server2.stop)
          self.discovery1 = UDPDiscovery(
              "discovery1",
              50053,
              5678,
              5679,
              lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(
                  peer_id, address, description, device_capabilities
              ),
          )
          self.discovery2 = UDPDiscovery(
              "discovery2",
              50054,
              5679,
              5678,
              lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(
                  peer_id, address, description, device_capabilities
              ),
          )
          await self.discovery1.start()
          self.addAsyncCleanup(self.discovery1.stop)
          await self.discovery2.start()
          self.addAsyncCleanup(self.discovery2.stop)

      async def test_grpc_discovery(self):
          """Discovered gRPC handles connect explicitly and notice a stopped server."""
          peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
          self.assertEqual(len(peers1), 1)
          peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
          self.assertEqual(len(peers2), 1)
          # Discovery alone must not open a connection.
          self.assertFalse(await peers1[0].is_connected())
          self.assertFalse(await peers2[0].is_connected())
          # Connect
          await peers1[0].connect()
          await peers2[0].connect()
          self.assertTrue(await peers1[0].is_connected())
          self.assertTrue(await peers2[0].is_connected())
          # Kill server1: peers2[0] (presumably discovery2's handle to server1 on
          # 50053) should drop, while peers1[0] (to server2) stays connected.
          # NOTE(review): if is_connected() reflects channel state, it may lag the
          # server shutdown briefly -- watch for flakiness here.
          await self.server1.stop()
          self.assertTrue(await peers1[0].is_connected())
          self.assertFalse(await peers2[0].is_connected())
  if __name__ == "__main__":
      # unittest.main() is a synchronous runner, not a coroutine, and by default
      # exits the process when the tests finish. Wrapping it in asyncio.run()
      # is therefore wrong: asyncio.run() never receives a value at all, or,
      # with exit=False, rejects the returned TestProgram. IsolatedAsyncioTestCase
      # creates its own event loop per test, so no outer loop is needed.
      unittest.main()