
Merge pull request #194 from exo-explore/better_networking

Better networking
Alex Cheema, 8 months ago
parent commit 35aba75be6

+ 3 - 3
.circleci/config.yml

@@ -17,11 +17,11 @@ commands:
            source env/bin/activate

            # Start first instance
-            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout-secs 900 > output1.log 2>&1 &
+            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 2>&1 | tee output1.log &
            PID1=$!

            # Start second instance
-            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout-secs 900 > output2.log 2>&1 &
+            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 2>&1 | tee output2.log &
            PID2=$!

            # Wait for discovery
@@ -144,7 +144,7 @@ jobs:
            PID2=$!
            sleep 10
            kill $PID1 $PID2
-            if grep -q "Connected to peer" output1.log && grep -q "Connected to peer" output2.log; then
+            if grep -q "Successfully connected peers: \['node2@.*:.*'\]" output1.log && ! grep -q "Failed to connect peers:" output1.log && grep -q "Successfully connected peers: \['node1@.*:.*'\]" output2.log && ! grep -q "Failed to connect peers:" output2.log; then
              echo "Test passed: Both instances discovered each other"
              exit 0
            else

+ 5 - 5
exo/api/chatgpt_api.py

@@ -153,10 +153,10 @@ class PromptSession:


 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
    self.node = node
    self.inference_engine_classname = inference_engine_classname
-    self.response_timeout_secs = response_timeout_secs
+    self.response_timeout = response_timeout
    self.on_chat_completion_request = on_chat_completion_request
    self.app = web.Application(client_max_size=100*1024*1024)  # 100MB to support image upload
    self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
@@ -255,7 +255,7 @@ class ChatGPTAPI:
      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

    try:
-      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")
+      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")

      if stream:
        response = web.StreamResponse(
@@ -304,7 +304,7 @@ class ChatGPTAPI:

          return _request_id == request_id and is_finished

-        _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
+        _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout)
        if request_id in self.stream_tasks:  # in case there is still a stream task running, wait for it to complete
          if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
          try:
@@ -316,7 +316,7 @@ class ChatGPTAPI:
      else:
        _, tokens, _ = await callback.wait(
          lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
-          timeout=self.response_timeout_secs,
+          timeout=self.response_timeout,
        )

        finish_reason = "length"

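Note on the change above: the constructor keyword is now response_timeout (was response_timeout_secs), matching the renamed --chatgpt-api-response-timeout flag in main.py below. A minimal construction sketch, assuming an importable exo checkout; the stubbed node and engine classname are illustrative, not taken from the diff:

from unittest import mock
from exo.api import ChatGPTAPI
from exo.orchestration.node import Node

api = ChatGPTAPI(
  node=mock.Mock(spec=Node),                          # stand-in node object, for illustration only
  inference_engine_classname="DummyInferenceEngine",  # illustrative name
  response_timeout=900,                               # renamed from response_timeout_secs
)
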
+ 4 - 4
exo/download/hf/hf_helpers.py

@@ -234,10 +234,9 @@ async def download_repo_files(
          raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
        revision_info = await response.json()
        commit_hash = revision_info['sha']
-
-      # Cache the commit hash
-      async with aiofiles.open(refs_file, 'w') as f:
-        await f.write(commit_hash)
+        # Cache the commit hash
+        async with aiofiles.open(refs_file, 'w') as f:
+          await f.write(commit_hash)

  # Set up the snapshot directory
  snapshot_dir = snapshots_dir/commit_hash
@@ -400,4 +399,5 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
      shard_specific_patterns.append(sorted_file_names[-1])
  else:
    shard_specific_patterns = ["*.safetensors"]
+  if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
  return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates

+ 4 - 0
exo/networking/grpc/grpc_peer_handle.py

@@ -23,12 +23,16 @@ class GRPCPeerHandle(PeerHandle):
  def id(self) -> str:
    return self._id

+  def addr(self) -> str:
+    return self.address
+
  def device_capabilities(self) -> DeviceCapabilities:
    return self._device_capabilities

  async def connect(self):
    self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
    self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
+    await self.channel.channel_ready()

  async def is_connected(self) -> bool:
    return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY

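Because connect() now awaits channel_ready(), it only returns once the channel is actually usable (and can hang if the peer is unreachable), which is why StandardNode below wraps it in asyncio.wait_for. A small sketch under those assumptions; the peer id and address are placeholders:

import asyncio
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES

async def probe_peer():
  # Placeholder id/address; no real node needs to be running to exercise the timeout path.
  peer = GRPCPeerHandle("node2", "192.168.0.10:50051", UNKNOWN_DEVICE_CAPABILITIES)
  try:
    await asyncio.wait_for(peer.connect(), timeout=5)  # returns only once the channel is READY
    print(f"connected to {peer.id()}@{peer.addr()}: {await peer.is_connected()}")
  except asyncio.TimeoutError:
    print(f"no peer reachable at {peer.addr()} within 5s")

asyncio.run(probe_peer())
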
+ 0 - 22
exo/networking/grpc/test_grpc_discovery.py

@@ -1,22 +0,0 @@
-import asyncio
-import unittest
-from .grpc_discovery import GRPCDiscovery
-
-
-class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase):
-  async def asyncSetUp(self):
-    self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
-    self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)
-    await self.node1.start()
-    await self.node2.start()
-
-  async def asyncTearDown(self):
-    await self.node1.stop()
-    await self.node2.stop()
-
-  async def test_discovery(self):
-    await asyncio.sleep(4)
-
-    # Check discovered peers
-    print("Node1 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node1.known_peers.items()]))
-    print("Node2 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node2.known_peers.items()]))

+ 7 - 4
exo/networking/peer_handle.py

@@ -5,12 +5,15 @@ from exo.inference.shard import Shard
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.topology.topology import Topology

-
 class PeerHandle(ABC):
  @abstractmethod
  def id(self) -> str:
    pass

+  @abstractmethod
+  def addr(self) -> str:
+    pass
+
  @abstractmethod
  def device_capabilities(self) -> DeviceCapabilities:
    pass
@@ -36,13 +39,13 @@
    pass

  @abstractmethod
-  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
    pass

  @abstractmethod
-  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
    pass

  @abstractmethod
-  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
    pass

+ 76 - 0
exo/networking/test_udp_discovery.py

@@ -0,0 +1,76 @@
+import asyncio
+import unittest
+from unittest import mock
+from exo.networking.udp_discovery import UDPDiscovery
+from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
+from exo.networking.grpc.grpc_server import GRPCServer
+from exo.orchestration.node import Node
+
+class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.peer1 = mock.AsyncMock()
+    self.peer2 = mock.AsyncMock()
+    self.peer1.connect = mock.AsyncMock()
+    self.peer2.connect = mock.AsyncMock()
+    self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+
+  async def test_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    assert len(peers1) == 1
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    assert len(peers2) == 1
+
+    # connect has to be explicitly called after discovery
+    self.peer1.connect.assert_not_called()
+    self.peer2.connect.assert_not_called()
+
+
+class TestUDPDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.node1 = mock.AsyncMock(spec=Node)
+    self.node2 = mock.AsyncMock(spec=Node)
+    self.server1 = GRPCServer(self.node1, "localhost", 50053)
+    self.server2 = GRPCServer(self.node2, "localhost", 50054)
+    await self.server1.start()
+    await self.server2.start()
+    self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+    await self.server1.stop()
+    await self.server2.stop()
+
+  async def test_grpc_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    assert len(peers1) == 1
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    assert len(peers2) == 1
+    assert not await peers1[0].is_connected()
+    assert not await peers2[0].is_connected()
+
+    # Connect
+    await peers1[0].connect()
+    await peers2[0].connect()
+    assert await peers1[0].is_connected()
+    assert await peers2[0].is_connected()
+
+    # Kill server1
+    await self.server1.stop()
+
+    assert await peers1[0].is_connected()
+    assert not await peers2[0].is_connected()
+
+
+if __name__ == "__main__":
+  asyncio.run(unittest.main())

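The new tests bind fixed ports (UDP 5678/5679, gRPC 50051-50054), so those ports must be free. One way to run them from the repo root, assuming the package is importable (this invocation is an assumption, not something the PR wires into CI):

import unittest

suite = unittest.defaultTestLoader.loadTestsFromName("exo.networking.test_udp_discovery")
unittest.TextTestRunner(verbosity=2).run(suite)
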
+ 50 - 87
exo/networking/grpc/grpc_discovery.py → exo/networking/udp_discovery.py

@@ -2,13 +2,12 @@ import asyncio
 import json
 import socket
 import time
+import traceback
 from typing import List, Dict, Callable, Tuple, Coroutine
-from ..discovery import Discovery
-from ..peer_handle import PeerHandle
-from .grpc_peer_handle import GRPCPeerHandle
+from .discovery import Discovery
+from .peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
-from exo import DEBUG_DISCOVERY
-
+from exo.helpers import DEBUG, DEBUG_DISCOVERY

 class ListenProtocol(asyncio.DatagramProtocol):
  def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
@@ -23,28 +22,30 @@ class ListenProtocol(asyncio.DatagramProtocol):
    asyncio.create_task(self.on_message(data, addr))


-class GRPCDiscovery(Discovery):
+class UDPDiscovery(Discovery):
  def __init__(
    self,
    node_id: str,
    node_port: int,
    listen_port: int,
-    broadcast_port: int = None,
+    broadcast_port: int,
+    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
    broadcast_interval: int = 1,
-    device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
    discovery_timeout: int = 30,
+    device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
  ):
    self.node_id = node_id
    self.node_port = node_port
-    self.device_capabilities = device_capabilities
    self.listen_port = listen_port
-    self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
+    self.broadcast_port = broadcast_port
+    self.create_peer_handle = create_peer_handle
    self.broadcast_interval = broadcast_interval
-    self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {}
+    self.discovery_timeout = discovery_timeout
+    self.device_capabilities = device_capabilities
+    self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
    self.broadcast_task = None
    self.listen_task = None
    self.cleanup_task = None
-    self.discovery_timeout = discovery_timeout

  async def start(self):
    self.device_capabilities = device_capabilities()
@@ -53,68 +54,45 @@
    self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())

  async def stop(self):
-    if self.broadcast_task:
-      self.broadcast_task.cancel()
-    if self.listen_task:
-      self.listen_task.cancel()
-    if self.cleanup_task:
-      self.cleanup_task.cancel()
+    if self.broadcast_task: self.broadcast_task.cancel()
+    if self.listen_task: self.listen_task.cancel()
+    if self.cleanup_task: self.cleanup_task.cancel()
    if self.broadcast_task or self.listen_task or self.cleanup_task:
      await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)

  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-    if DEBUG_DISCOVERY >= 2:
-      print("Starting peer discovery process...")
-
    if wait_for_peers > 0:
-      while len(self.known_peers) == 0:
-        if DEBUG_DISCOVERY >= 2:
-          print("No peers discovered yet, retrying in 1 second...")
-        await asyncio.sleep(1)  # Keep trying to find peers
-      if DEBUG_DISCOVERY >= 2:
-        print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
-
-    grace_period = 5  # seconds
-    while True:
-      initial_peer_count = len(self.known_peers)
-      if DEBUG_DISCOVERY >= 2:
-        print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
-      if len(self.known_peers) == initial_peer_count:
-        if wait_for_peers > 0:
-          await asyncio.sleep(grace_period)
-          if DEBUG_DISCOVERY >= 2:
-            print(f"Waiting additional {wait_for_peers} seconds for more peers.")
-          wait_for_peers = 0
-        else:
-          if DEBUG_DISCOVERY >= 2:
-            print("No new peers discovered in the last grace period. Ending discovery process.")
-          break  # No new peers found in the grace period, we are done
-
+      while len(self.known_peers) < wait_for_peers:
+        if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
+        await asyncio.sleep(0.1)
    return [peer_handle for peer_handle, _, _ in self.known_peers.values()]

  async def task_broadcast_presence(self):
-    transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET)
-    sock = transport.get_extra_info("socket")
-    sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
-
-    message = json.dumps({
-      "type": "discovery",
-      "node_id": self.node_id,
-      "grpc_port": self.node_port,
-      "device_capabilities": self.device_capabilities.to_dict(),
-    }).encode("utf-8")
-
    while True:
      try:
-        if DEBUG_DISCOVERY >= 3:
-          print(f"Broadcast presence: {message}")
+        message = json.dumps({
+          "type": "discovery",
+          "node_id": self.node_id,
+          "grpc_port": self.node_port,
+          "device_capabilities": self.device_capabilities.to_dict(),
+        }).encode("utf-8")
+        if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
+
+        transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET)
+        sock = transport.get_extra_info("socket")
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        transport.sendto(message, ("<broadcast>", self.broadcast_port))
-        await asyncio.sleep(self.broadcast_interval)
      except Exception as e:
        print(f"Error in broadcast presence: {e}")
-        import traceback
-
        print(traceback.format_exc())
+      finally:
+        if transport:
+          try:
+            transport.close()
+          except:
+            if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
+            if DEBUG_DISCOVERY >= 2: traceback.print_exc()
+        await asyncio.sleep(self.broadcast_interval)

  async def on_listen_message(self, data, addr):
    if not data:
@@ -124,40 +102,35 @@ class GRPCDiscovery(Discovery):

    # Check if the decoded data starts with a valid JSON character
    if not (decoded_data.strip() and decoded_data.strip()[0] in "{["):
-      if DEBUG_DISCOVERY >= 2:
-        print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}")
+      if DEBUG_DISCOVERY >= 2: print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}")
      return

    try:
      decoder = json.JSONDecoder(strict=False)
      message = decoder.decode(decoded_data)
    except json.JSONDecodeError as e:
-      if DEBUG_DISCOVERY >= 2:
-        print(f"Error decoding JSON data from {addr}: {e}")
+      if DEBUG_DISCOVERY >= 2: print(f"Error decoding JSON data from {addr}: {e}")
      return

-    if DEBUG_DISCOVERY >= 2:
-      print(f"received from peer {addr}: {message}")
+    if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")

    if message["type"] == "discovery" and message["node_id"] != self.node_id:
      peer_id = message["node_id"]
      peer_host = addr[0]
      peer_port = message["grpc_port"]
      device_capabilities = DeviceCapabilities(**message["device_capabilities"])
-      if peer_id not in self.known_peers:
+      if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
+        if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
        self.known_peers[peer_id] = (
-          GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
+          self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
          time.time(),
          time.time(),
        )
-        if DEBUG_DISCOVERY >= 2:
-          print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
      self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())

  async def task_listen_for_peers(self):
    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port))
-    if DEBUG_DISCOVERY >= 2:
-      print("Started listen task")
+    if DEBUG_DISCOVERY >= 2: print("Started listen task")

  async def task_cleanup_peers(self):
    while True:
@@ -167,22 +140,12 @@ class GRPCDiscovery(Discovery):
          peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
        ]
-        if DEBUG_DISCOVERY >= 2:
-          print(
-            "Peer statuses:",
-            {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}"
-             for peer_handle, connected_at, last_seen in self.known_peers.values()},
-          )
-        if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0:
-          print(f"Cleaning up peers: {peers_to_remove}")
+        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
        for peer_id in peers_to_remove:
-          if peer_id in self.known_peers:
-            del self.known_peers[peer_id]
-          if DEBUG_DISCOVERY >= 2:
-            print(f"Removed peer {peer_id} due to inactivity.")
-        await asyncio.sleep(self.broadcast_interval)
+          if peer_id in self.known_peers: del self.known_peers[peer_id]
+          if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
      except Exception as e:
        print(f"Error in cleanup peers: {e}")
-        import traceback
-
        print(traceback.format_exc())
+      finally:
+        await asyncio.sleep(self.broadcast_interval)

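With the rename, discovery is transport-agnostic: UDPDiscovery only broadcasts and listens for presence messages, and builds peers through the injected create_peer_handle factory. A usage sketch against the signature shown above; the node id and ports are illustrative:

import asyncio
from exo.networking.udp_discovery import UDPDiscovery
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle

async def run_discovery():
  discovery = UDPDiscovery(
    "node1",  # node_id
    50051,    # node_port, advertised to peers as "grpc_port"
    5678,     # listen_port for incoming UDP broadcasts
    5679,     # broadcast_port that presence messages are sent to
    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
  )
  await discovery.start()
  peers = await discovery.discover_peers(wait_for_peers=1)  # now simply blocks until enough peers are known
  print([f"{p.id()}@{p.addr()}" for p in peers])
  await discovery.stop()

asyncio.run(run_discovery())
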
+ 67 - 15
exo/orchestration/standard_node.py

@@ -50,7 +50,7 @@ class StandardNode(Node):
    await self.update_peers(wait_for_peers)
    await self.collect_topology()
    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
-    asyncio.create_task(self.periodic_topology_collection(5))
+    asyncio.create_task(self.periodic_topology_collection(1.0))

  async def stop(self) -> None:
    await self.discovery.stop()
@@ -277,23 +277,75 @@ class StandardNode(Node):
      raise ValueError(f"No current partition found for node: {self.id}")
    return shards[current_partition_index]

-  async def update_peers(self, wait_for_peers: int = 0) -> None:
-    self.peers = await self.discovery.discover_peers(wait_for_peers)
-    for peer in self.peers:
-      is_connected = await peer.is_connected()
-      if DEBUG >= 2 and is_connected:
-        print(f"Already connected to {peer.id()}: {is_connected}")
-      if not is_connected:
-        if DEBUG >= 2: print(f"Connecting to {peer.id()}...")
-        await peer.connect()
-        if DEBUG >= 1: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})")
+  async def update_peers(self, wait_for_peers: int = 0) -> bool:
+    next_peers = await self.discovery.discover_peers(wait_for_peers)
+    current_peer_ids = {peer.id() for peer in self.peers}
+    next_peer_ids = {peer.id() for peer in next_peers}
+    peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
+    peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
+    peers_updated = [
+      peer for peer in next_peers
+      if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())
+    ]
+    peers_unchanged = [
+      peer for peer in next_peers
+      if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())
+    ]
+    peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
+    peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]
+
+    def _pretty(peers: List[PeerHandle]) -> List[str]:
+      return [f"{peer.id()}@{peer.addr()}" for peer in peers]
+    if DEBUG >= 2: print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
+
+    async def disconnect_with_timeout(peer, timeout=5):
+      try:
+        await asyncio.wait_for(peer.disconnect(), timeout)
+        return True
+      except Exception as e:
+        print(f"Error disconnecting peer {peer.id()}@{peer.addr()}: {e}")
+        traceback.print_exc()
+        return False
+
+    async def connect_with_timeout(peer, timeout=5):
+      try:
+        await asyncio.wait_for(peer.connect(), timeout)
+        return True
+      except Exception as e:
+        print(f"Error connecting peer {peer.id()}@{peer.addr()}: {e}")
+        traceback.print_exc()
+        return False
+
+    disconnect_results = await asyncio.gather(
+      *(disconnect_with_timeout(peer) for peer in peers_to_disconnect),
+      return_exceptions=True
+    )
+    connect_results = await asyncio.gather(
+      *(connect_with_timeout(peer) for peer in peers_to_connect),
+      return_exceptions=True
+    )
+
+    successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True]
+    failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False]
+    successful_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is True]
+    failed_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is False]
+    if DEBUG >= 1:
+      if successful_disconnects: print(f"Successfully disconnected peers: {_pretty(successful_disconnects)}")
+      if failed_disconnects: print(f"Failed to disconnect peers: {_pretty(failed_disconnects)}")
+      if successful_connects: print(f"Successfully connected peers: {_pretty(successful_connects)}")
+      if failed_connects: print(f"Failed to connect peers: {_pretty(failed_connects)}")
+
+    self.peers = next_peers
+    return len(peers_to_connect) > 0 or len(peers_to_disconnect) > 0

  async def periodic_topology_collection(self, interval: int):
    while True:
      await asyncio.sleep(interval)
      try:
-        await self.update_peers()
-        await self.collect_topology()
+        did_peers_change = await self.update_peers()
+        if DEBUG >= 2: print(f"{did_peers_change=}")
+        if did_peers_change:
+          await self.collect_topology()
      except Exception as e:
        print(f"Error collecting topology: {e}")
        traceback.print_exc()
@@ -310,7 +362,7 @@ class StandardNode(Node):
    if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")

    prev_visited = visited.copy()
-    # TODO: should we add our own peer id here?
+    visited.add(self.id)
    visited.update(p.id() for p in self.peers)

    for peer in self.peers:
@@ -325,7 +377,7 @@ class StandardNode(Node):
        continue

      try:
-        other_topology = await peer.collect_topology(visited, max_depth=max_depth - 1)
+        other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0)
        if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
        self.topology.merge(other_topology)
      except Exception as e:

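The rewritten update_peers diffs the freshly discovered peers against self.peers, bounds every connect/disconnect with a 5-second asyncio.wait_for, and returns whether anything changed so periodic_topology_collection only re-collects topology when it has to. A standalone, simplified illustration of that bounded-connect pattern (not exo code; peers can be any objects exposing an async connect()):

import asyncio

async def connect_all(peers, timeout=5):
  async def connect_with_timeout(peer):
    try:
      await asyncio.wait_for(peer.connect(), timeout)  # one slow peer costs at most `timeout` seconds
      return True
    except Exception:
      return False

  results = await asyncio.gather(*(connect_with_timeout(p) for p in peers))
  connected = [p for p, ok in zip(peers, results) if ok]
  failed = [p for p, ok in zip(peers, results) if not ok]
  return connected, failed
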
+ 5 - 4
main.py

@@ -7,7 +7,8 @@ import traceback
 import uuid
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
-from exo.networking.grpc.grpc_discovery import GRPCDiscovery
+from exo.networking.udp_discovery import UDPDiscovery
+from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent
@@ -33,7 +34,7 @@ parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
-parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
+parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
@@ -66,7 +67,7 @@ if DEBUG >= 0:
  for chatgpt_api_endpoint in chatgpt_api_endpoints:
    print(f" - {terminal_link(chatgpt_api_endpoint)}")

-discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
+discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
  args.node_id,
@@ -82,7 +83,7 @@ node.server = server
 api = ChatGPTAPI(
  node,
  inference_engine.__class__.__name__,
-  response_timeout_secs=args.chatgpt_api_response_timeout_secs,
+  response_timeout=args.chatgpt_api_response_timeout,
  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
 )
 node.on_token.register("update_topology_viz").on_next(

+ 21 - 0
test/reconnect.sh

@@ -0,0 +1,21 @@
+#!/bin/bash
+
+echo "Starting node 1"
+DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
+PID1=$!
+echo "Started node 1 PID: $PID1"
+echo "Starting node 2"
+DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 > output2.log 2>&1 &
+PID2=$!
+echo "Started node 2 PID: $PID2"
+sleep 5
+kill $PID2
+sleep 5
+echo "Starting node 2 again..."
+DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 > output3.log 2>&1 &
+PID2=$!
+sleep 5
+echo "Killing nodes and ending test..."
+kill $PID1
+kill $PID2
+echo "Test complete."