Просмотр исходного кода

generalise UDPDiscovery to any kind of PeerHandle that accepts an address. test it

Alex Cheema 8 месяцев назад
Родитель
Коммит
f93f811dcb

+ 8 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -1,5 +1,6 @@
 import grpc
 import numpy as np
+import asyncio
 from typing import Optional, Tuple, List
 
 # These would be generated from the .proto file
@@ -26,9 +27,15 @@ class GRPCPeerHandle(PeerHandle):
   def device_capabilities(self) -> DeviceCapabilities:
     return self._device_capabilities
 
-  async def connect(self):
+  async def connect(self, timeout: float = 5.0):
     self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
     self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
+    try:
+      async with asyncio.timeout(timeout): await self.channel.channel_ready()
+    except asyncio.TimeoutError:
+      print("Connection attempt timed out")
+      await self.disconnect()
+      raise
 
   async def is_connected(self) -> bool:
     return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY

+ 0 - 22
exo/networking/grpc/test_grpc_discovery.py

@@ -1,22 +0,0 @@
-import asyncio
-import unittest
-from ..udp_discovery import UDPDiscovery
-
-
-class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
-  async def asyncSetUp(self):
-    self.node1 = UDPDiscovery("node1", 50051, 5678, 5679)
-    self.node2 = UDPDiscovery("node2", 50052, 5679, 5678)
-    await self.node1.start()
-    await self.node2.start()
-
-  async def asyncTearDown(self):
-    await self.node1.stop()
-    await self.node2.stop()
-
-  async def test_discovery(self):
-    await asyncio.sleep(4)
-
-    # Check discovered peers
-    print("Node1 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node1.known_peers.items()]))
-    print("Node2 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node2.known_peers.items()]))

+ 32 - 0
exo/networking/test_udp_discovery.py

@@ -0,0 +1,32 @@
+import asyncio
+import unittest
+from unittest import mock  # Add this import
+from exo.networking.udp_discovery import UDPDiscovery
+
+class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.peer1 = mock.AsyncMock()
+    self.peer2 = mock.AsyncMock()
+    self.peer1.connect = mock.AsyncMock()
+    self.peer2.connect = mock.AsyncMock()
+    self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+
+  async def test_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    assert len(peers1) == 1
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    assert len(peers2) == 1
+
+    # connect has to be explicitly called after discovery
+    self.peer1.connect.assert_not_called()
+    self.peer2.connect.assert_not_called()
+
+if __name__ == "__main__":
+  asyncio.run(unittest.main())

+ 6 - 4
exo/networking/udp_discovery.py

@@ -3,7 +3,7 @@ import json
 import socket
 import time
 import traceback
-from typing import List, Dict, Callable, Tuple, Coroutine
+from typing import List, Dict, Callable, Tuple, Coroutine, Type
 from .discovery import Discovery
 from .peer_handle import PeerHandle
 from .grpc.grpc_peer_handle import GRPCPeerHandle
@@ -34,6 +34,7 @@ class UDPDiscovery(Discovery):
     broadcast_interval: int = 1,
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
     discovery_timeout: int = 30,
+    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle] = lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
   ):
     self.node_id = node_id
     self.node_port = node_port
@@ -41,11 +42,12 @@ class UDPDiscovery(Discovery):
     self.listen_port = listen_port
     self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
     self.broadcast_interval = broadcast_interval
-    self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {}
+    self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
     self.broadcast_task = None
     self.listen_task = None
     self.cleanup_task = None
     self.discovery_timeout = discovery_timeout
+    self.create_peer_handle = create_peer_handle
 
   async def start(self):
     self.device_capabilities = device_capabilities()
@@ -71,7 +73,7 @@ class UDPDiscovery(Discovery):
       while len(self.known_peers) == 0:
         if DEBUG_DISCOVERY >= 2:
           print("No peers discovered yet, retrying in 1 second...")
-        await asyncio.sleep(1)  # Keep trying to find peers
+        await asyncio.sleep(0.1)  # Keep trying to find peers
       if DEBUG_DISCOVERY >= 2:
         print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
 
@@ -145,7 +147,7 @@ class UDPDiscovery(Discovery):
       device_capabilities = DeviceCapabilities(**message["device_capabilities"])
       if peer_id not in self.known_peers:
         self.known_peers[peer_id] = (
-          GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
+          self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
           time.time(),
           time.time(),
         )

+ 3 - 3
exo/orchestration/standard_node.py

@@ -281,9 +281,9 @@ class StandardNode(Node):
     self.peers = await self.discovery.discover_peers(wait_for_peers)
     for peer in self.peers:
       is_connected = await peer.is_connected()
-      if DEBUG >= 2 and is_connected:
-        print(f"Already connected to {peer.id()}: {is_connected}")
-      if not is_connected:
+      if is_connected:
+        if DEBUG >= 2: print(f"Already connected to {peer.id()}: {is_connected}")
+      else:
         if DEBUG >= 2: print(f"Connecting to {peer.id()}...")
         await peer.connect()
         if DEBUG >= 1: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})")