Bläddra i källkod

cleaner discovery

Alex Cheema 8 månader sedan
förälder
incheckning
baf6efd321

+ 2 - 8
exo/networking/grpc/grpc_peer_handle.py

@@ -1,6 +1,5 @@
 import grpc
 import numpy as np
-import asyncio
 from typing import Optional, Tuple, List
 
 # These would be generated from the .proto file
@@ -27,15 +26,10 @@ class GRPCPeerHandle(PeerHandle):
   def device_capabilities(self) -> DeviceCapabilities:
     return self._device_capabilities
 
-  async def connect(self, timeout: float = 5.0):
+  async def connect(self):
     self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
     self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
-    try:
-      async with asyncio.timeout(timeout): await self.channel.channel_ready()
-    except asyncio.TimeoutError:
-      print("Connection attempt timed out")
-      await self.disconnect()
-      raise
+    await self.channel.channel_ready()
 
   async def is_connected(self) -> bool:
     return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY

+ 2 - 2
exo/networking/test_udp_discovery.py

@@ -40,8 +40,8 @@ class TestUDPDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.server2 = GRPCServer(self.node2, "localhost", 50054)
     await self.server1.start()
     await self.server2.start()
-    self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679)
-    self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678)
+    self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
     await self.discovery1.start()
     await self.discovery2.start()
 

+ 21 - 60
exo/networking/udp_discovery.py

@@ -6,7 +6,6 @@ import traceback
 from typing import List, Dict, Callable, Tuple, Coroutine, Type
 from .discovery import Discovery
 from .peer_handle import PeerHandle
-from .grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo import DEBUG_DISCOVERY
 
@@ -30,11 +29,11 @@ class UDPDiscovery(Discovery):
     node_id: str,
     node_port: int,
     listen_port: int,
+    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
     broadcast_port: int = None,
     broadcast_interval: int = 1,
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
     discovery_timeout: int = 30,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle] = lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
   ):
     self.node_id = node_id
     self.node_port = node_port
@@ -56,43 +55,18 @@ class UDPDiscovery(Discovery):
     self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
 
   async def stop(self):
-    if self.broadcast_task:
-      self.broadcast_task.cancel()
-    if self.listen_task:
-      self.listen_task.cancel()
-    if self.cleanup_task:
-      self.cleanup_task.cancel()
+    if self.broadcast_task: self.broadcast_task.cancel()
+    if self.listen_task: self.listen_task.cancel()
+    if self.cleanup_task: self.cleanup_task.cancel()
     if self.broadcast_task or self.listen_task or self.cleanup_task:
       await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
 
   async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-    if DEBUG_DISCOVERY >= 2:
-      print("Starting peer discovery process...")
-
+    if DEBUG_DISCOVERY >= 2: print("Starting peer discovery process...")
     if wait_for_peers > 0:
-      while len(self.known_peers) == 0:
-        if DEBUG_DISCOVERY >= 2:
-          print("No peers discovered yet, retrying in 1 second...")
-        await asyncio.sleep(0.1)  # Keep trying to find peers
-      if DEBUG_DISCOVERY >= 2:
-        print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
-
-    grace_period = 5  # seconds
-    while True:
-      initial_peer_count = len(self.known_peers)
-      if DEBUG_DISCOVERY >= 2:
-        print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
-      if len(self.known_peers) == initial_peer_count:
-        if wait_for_peers > 0:
-          await asyncio.sleep(grace_period)
-          if DEBUG_DISCOVERY >= 2:
-            print(f"Waiting additional {wait_for_peers} seconds for more peers.")
-          wait_for_peers = 0
-        else:
-          if DEBUG_DISCOVERY >= 2:
-            print("No new peers discovered in the last grace period. Ending discovery process.")
-          break  # No new peers found in the grace period, we are done
-
+      while len(self.known_peers) < wait_for_peers:
+        if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
+        await asyncio.sleep(0.1)
     return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
 
   async def task_broadcast_presence(self):
@@ -109,13 +83,13 @@ class UDPDiscovery(Discovery):
 
     while True:
       try:
-        if DEBUG_DISCOVERY >= 3:
-          print(f"Broadcast presence: {message}")
+        if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
         transport.sendto(message, ("<broadcast>", self.broadcast_port))
-        await asyncio.sleep(self.broadcast_interval)
       except Exception as e:
         print(f"Error in broadcast presence: {e}")
         print(traceback.format_exc())
+      finally:
+        await asyncio.sleep(self.broadcast_interval)
 
   async def on_listen_message(self, data, addr):
     if not data:
@@ -125,20 +99,17 @@ class UDPDiscovery(Discovery):
 
     # Check if the decoded data starts with a valid JSON character
     if not (decoded_data.strip() and decoded_data.strip()[0] in "{["):
-      if DEBUG_DISCOVERY >= 2:
-        print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}")
+      if DEBUG_DISCOVERY >= 2: print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}")
       return
 
     try:
       decoder = json.JSONDecoder(strict=False)
       message = decoder.decode(decoded_data)
     except json.JSONDecodeError as e:
-      if DEBUG_DISCOVERY >= 2:
-        print(f"Error decoding JSON data from {addr}: {e}")
+      if DEBUG_DISCOVERY >= 2: print(f"Error decoding JSON data from {addr}: {e}")
       return
 
-    if DEBUG_DISCOVERY >= 2:
-      print(f"received from peer {addr}: {message}")
+    if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")
 
     if message["type"] == "discovery" and message["node_id"] != self.node_id:
       peer_id = message["node_id"]
@@ -151,14 +122,12 @@ class UDPDiscovery(Discovery):
           time.time(),
           time.time(),
         )
-        if DEBUG_DISCOVERY >= 2:
-          print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
+        if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
       self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
 
   async def task_listen_for_peers(self):
     await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port))
-    if DEBUG_DISCOVERY >= 2:
-      print("Started listen task")
+    if DEBUG_DISCOVERY >= 2: print("Started listen task")
 
   async def task_cleanup_peers(self):
     while True:
@@ -168,20 +137,12 @@ class UDPDiscovery(Discovery):
           peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
           if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
         ]
-        if DEBUG_DISCOVERY >= 2:
-          print(
-            "Peer statuses:",
-            {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}"
-             for peer_handle, connected_at, last_seen in self.known_peers.values()},
-          )
-        if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0:
-          print(f"Cleaning up peers: {peers_to_remove}")
+        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
         for peer_id in peers_to_remove:
-          if peer_id in self.known_peers:
-            del self.known_peers[peer_id]
-          if DEBUG_DISCOVERY >= 2:
-            print(f"Removed peer {peer_id} due to inactivity.")
-        await asyncio.sleep(self.broadcast_interval)
+          if peer_id in self.known_peers: del self.known_peers[peer_id]
+          if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
       except Exception as e:
         print(f"Error in cleanup peers: {e}")
         print(traceback.format_exc())
+      finally:
+        await asyncio.sleep(self.broadcast_interval)

+ 2 - 1
main.py

@@ -8,6 +8,7 @@ import uuid
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.udp_discovery import UDPDiscovery
+from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent
@@ -66,7 +67,7 @@ if DEBUG >= 0:
   for chatgpt_api_endpoint in chatgpt_api_endpoints:
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
 
-discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
+discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
   args.node_id,