udp_discovery.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import asyncio
  2. import json
  3. import socket
  4. import time
  5. import traceback
  6. from typing import List, Dict, Callable, Tuple, Coroutine
  7. from exo.networking.discovery import Discovery
  8. from exo.networking.peer_handle import PeerHandle
  9. from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
  10. from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
  11. class ListenProtocol(asyncio.DatagramProtocol):
  12. def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
  13. super().__init__()
  14. self.on_message = on_message
  15. self.loop = asyncio.get_event_loop()
  16. def connection_made(self, transport):
  17. self.transport = transport
  18. def datagram_received(self, data, addr):
  19. asyncio.create_task(self.on_message(data, addr))
  20. class BroadcastProtocol(asyncio.DatagramProtocol):
  21. def __init__(self, message: str, broadcast_port: int):
  22. self.message = message
  23. self.broadcast_port = broadcast_port
  24. def connection_made(self, transport):
  25. sock = transport.get_extra_info("socket")
  26. sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
  27. transport.sendto(self.message.encode("utf-8"), ("<broadcast>", self.broadcast_port))
  28. class UDPDiscovery(Discovery):
  29. def __init__(
  30. self,
  31. node_id: str,
  32. node_port: int,
  33. listen_port: int,
  34. broadcast_port: int,
  35. create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
  36. broadcast_interval: int = 1,
  37. discovery_timeout: int = 30,
  38. device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
  39. ):
  40. self.node_id = node_id
  41. self.node_port = node_port
  42. self.listen_port = listen_port
  43. self.broadcast_port = broadcast_port
  44. self.create_peer_handle = create_peer_handle
  45. self.broadcast_interval = broadcast_interval
  46. self.discovery_timeout = discovery_timeout
  47. self.device_capabilities = device_capabilities
  48. self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
  49. self.broadcast_task = None
  50. self.listen_task = None
  51. self.cleanup_task = None
  52. async def start(self):
  53. self.device_capabilities = device_capabilities()
  54. self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
  55. self.listen_task = asyncio.create_task(self.task_listen_for_peers())
  56. self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
  57. async def stop(self):
  58. if self.broadcast_task: self.broadcast_task.cancel()
  59. if self.listen_task: self.listen_task.cancel()
  60. if self.cleanup_task: self.cleanup_task.cancel()
  61. if self.broadcast_task or self.listen_task or self.cleanup_task:
  62. await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
  63. async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
  64. if wait_for_peers > 0:
  65. while len(self.known_peers) < wait_for_peers:
  66. if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
  67. await asyncio.sleep(0.1)
  68. return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]
  69. async def task_broadcast_presence(self):
  70. if DEBUG_DISCOVERY >= 2: print("Starting task_broadcast_presence...")
  71. while True:
  72. # Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
  73. # the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
  74. for addr in get_all_ip_addresses():
  75. message = json.dumps({
  76. "type": "discovery",
  77. "node_id": self.node_id,
  78. "grpc_port": self.node_port,
  79. "device_capabilities": self.device_capabilities.to_dict(),
  80. "priority": 1, # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
  81. })
  82. if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")
  83. transport = None
  84. try:
  85. transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
  86. lambda: BroadcastProtocol(message, self.broadcast_port),
  87. local_addr=(addr, 0),
  88. family=socket.AF_INET
  89. )
  90. if DEBUG_DISCOVERY >= 3:
  91. print(f"Broadcasting presence at ({addr})")
  92. except Exception as e:
  93. print(f"Error in broadcast presence ({addr}): {e}")
  94. finally:
  95. if transport:
  96. try:
  97. transport.close()
  98. except Exception as e:
  99. if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
  100. if DEBUG_DISCOVERY >= 2: traceback.print_exc()
  101. await asyncio.sleep(self.broadcast_interval)
  102. async def on_listen_message(self, data, addr):
  103. if not data:
  104. return
  105. decoded_data = data.decode("utf-8", errors="ignore")
  106. # Check if the decoded data starts with a valid JSON character
  107. if not (decoded_data.strip() and decoded_data.strip()[0] in "{["):
  108. if DEBUG_DISCOVERY >= 2: print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}")
  109. return
  110. try:
  111. decoder = json.JSONDecoder(strict=False)
  112. message = decoder.decode(decoded_data)
  113. except json.JSONDecodeError as e:
  114. if DEBUG_DISCOVERY >= 2: print(f"Error decoding JSON data from {addr}: {e}")
  115. return
  116. if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")
  117. if message["type"] == "discovery" and message["node_id"] != self.node_id:
  118. peer_id = message["node_id"]
  119. peer_host = addr[0]
  120. peer_port = message["grpc_port"]
  121. peer_prio = message["priority"]
  122. device_capabilities = DeviceCapabilities(**message["device_capabilities"])
  123. if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
  124. if peer_id in self.known_peers:
  125. existing_peer_prio = self.known_peers[peer_id][3]
  126. if existing_peer_prio >= peer_prio:
  127. if DEBUG >= 1: print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
  128. return
  129. new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
  130. if not await new_peer_handle.health_check():
  131. if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
  132. return
  133. if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
  134. self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time(), peer_prio)
  135. else:
  136. if not await self.known_peers[peer_id][0].health_check():
  137. if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Removing.")
  138. if peer_id in self.known_peers: del self.known_peers[peer_id]
  139. return
  140. if peer_id in self.known_peers: self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
  141. async def task_listen_for_peers(self):
  142. await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
  143. local_addr=("0.0.0.0", self.listen_port))
  144. if DEBUG_DISCOVERY >= 2: print("Started listen task")
  145. async def task_cleanup_peers(self):
  146. while True:
  147. try:
  148. current_time = time.time()
  149. peers_to_remove = []
  150. peer_ids = list(self.known_peers.keys())
  151. results = await asyncio.gather(*[self.check_peer(peer_id, current_time) for peer_id in peer_ids], return_exceptions=True)
  152. for peer_id, should_remove in zip(peer_ids, results):
  153. if should_remove: peers_to_remove.append(peer_id)
  154. if DEBUG_DISCOVERY >= 2: print("Peer statuses:", { peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}, prio={prio}" for peer_handle, connected_at, last_seen, prio in self.known_peers.values() })
  155. for peer_id in peers_to_remove:
  156. if peer_id in self.known_peers:
  157. del self.known_peers[peer_id]
  158. if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity or failed health check.")
  159. except Exception as e:
  160. print(f"Error in cleanup peers: {e}")
  161. print(traceback.format_exc())
  162. finally:
  163. await asyncio.sleep(self.broadcast_interval)
  164. async def check_peer(self, peer_id: str, current_time: float) -> bool:
  165. peer_handle, connected_at, last_seen, prio = self.known_peers.get(peer_id, (None, None, None, None))
  166. if peer_handle is None: return False
  167. try:
  168. is_connected = await peer_handle.is_connected()
  169. health_ok = await peer_handle.health_check()
  170. except Exception as e:
  171. if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
  172. return True
  173. should_remove = (
  174. (not is_connected and current_time - connected_at > self.discovery_timeout) or
  175. (current_time - last_seen > self.discovery_timeout) or
  176. (not health_ok)
  177. )
  178. return should_remove