瀏覽代碼

More robust discovery / peer handling: we now track whether the same node ID changes address, and if it does we immediately connect to it

Alex Cheema 11 月之前
父節點
當前提交
355c579965

+ 3 - 0
exo/networking/grpc/grpc_peer_handle.py

@@ -23,6 +23,9 @@ class GRPCPeerHandle(PeerHandle):
   def id(self) -> str:
     return self._id
 
+  def addr(self) -> str:
+    return self.address
+
   def device_capabilities(self) -> DeviceCapabilities:
     return self._device_capabilities
 

+ 4 - 0
exo/networking/peer_handle.py

@@ -10,6 +10,10 @@ class PeerHandle(ABC):
   def id(self) -> str:
     pass
 
+  @abstractmethod
+  def addr(self) -> str:
+    pass
+
   @abstractmethod
   def device_capabilities(self) -> DeviceCapabilities:
     pass

+ 3 - 3
exo/networking/udp_discovery.py

@@ -7,7 +7,7 @@ from typing import List, Dict, Callable, Tuple, Coroutine
 from .discovery import Discovery
 from .peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
-from exo import DEBUG_DISCOVERY
+from exo.helpers import DEBUG, DEBUG_DISCOVERY
 
 class ListenProtocol(asyncio.DatagramProtocol):
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
@@ -115,13 +115,13 @@ class UDPDiscovery(Discovery):
       peer_host = addr[0]
       peer_port = message["grpc_port"]
       device_capabilities = DeviceCapabilities(**message["device_capabilities"])
-      if peer_id not in self.known_peers:
+      if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
+        if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
         self.known_peers[peer_id] = (
           self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
           time.time(),
           time.time(),
         )
-        if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
       self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
 
   async def task_listen_for_peers(self):

+ 46 - 13
exo/orchestration/standard_node.py

@@ -50,7 +50,7 @@ class StandardNode(Node):
     await self.update_peers(wait_for_peers)
     await self.collect_topology()
     if DEBUG >= 2: print(f"Collected topology: {self.topology}")
-    asyncio.create_task(self.periodic_topology_collection(5))
+    asyncio.create_task(self.periodic_topology_collection(1.0))
 
   async def stop(self) -> None:
     await self.discovery.stop()
@@ -277,23 +277,56 @@ class StandardNode(Node):
       raise ValueError(f"No current partition found for node: {self.id}")
     return shards[current_partition_index]
 
-  async def update_peers(self, wait_for_peers: int = 0) -> None:
-    self.peers = await self.discovery.discover_peers(wait_for_peers)
-    for peer in self.peers:
-      is_connected = await peer.is_connected()
-      if is_connected:
-        if DEBUG >= 2: print(f"Already connected to {peer.id()}: {is_connected}")
-      else:
-        if DEBUG >= 2: print(f"Connecting to {peer.id()}...")
-        await peer.connect()
-        if DEBUG >= 1: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})")
+  async def update_peers(self, wait_for_peers: int = 0) -> bool:
+    next_peers = await self.discovery.discover_peers(wait_for_peers)
+    current_peer_ids = {peer.id() for peer in self.peers}
+    next_peer_ids = {peer.id() for peer in next_peers}
+    peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
+    peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
+    peers_updated = [
+      peer for peer in next_peers
+      if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())
+    ]
+    peers_unchanged = [
+      peer for peer in next_peers
+      if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())
+    ]
+    peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
+    peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]
+
+    print(f"{peers_added=} {peers_removed=} {peers_updated=} {peers_unchanged=} {peers_to_disconnect=} {peers_to_connect=}")
+
+    async def disconnect_with_timeout(peer, timeout=5):
+      try:
+        await asyncio.wait_for(peer.disconnect(), timeout)
+      except Exception as e:
+        print(f"Error disconnecting peer {peer.id()}@{peer.addr()}: {e}")
+        traceback.print_exc()
+
+    async def connect_with_timeout(peer, timeout=5):
+      try:
+        await asyncio.wait_for(peer.connect(), timeout)
+      except Exception as e:
+        print(f"Error connecting peer {peer.id()}@{peer.addr()}: {e}")
+        traceback.print_exc()
+
+    await asyncio.gather(
+      *(disconnect_with_timeout(peer) for peer in peers_to_disconnect),
+      *(connect_with_timeout(peer) for peer in peers_to_connect),
+      return_exceptions=True
+    )
+
+    self.peers = next_peers
+    return len(peers_to_connect) > 0 or len(peers_to_disconnect) > 0
 
   async def periodic_topology_collection(self, interval: int):
     while True:
       await asyncio.sleep(interval)
       try:
-        await self.update_peers()
-        await self.collect_topology()
+        did_peers_change = await self.update_peers()
+        if DEBUG >= 2: print(f"{did_peers_change=}")
+        if did_peers_change:
+          await self.collect_topology()
       except Exception as e:
         print(f"Error collecting topology: {e}")
         traceback.print_exc()

+ 5 - 1
test/reconnect.sh

@@ -14,4 +14,8 @@ sleep 5
 echo "Starting node 2 again..."
 DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout-secs 900 > output3.log 2>&1 &
 PID2=$!
-kill $PID2
+sleep 5
+echo "Killing nodes and ending test..."
+kill $PID1
+kill $PID2
+echo "Test complete."