# grpc_discovery.py (5.3 KB)
import asyncio
import json
import socket
import time
from typing import Dict, List, Optional

from topology.device_capabilities import DeviceCapabilities, device_capabilities

from ..discovery import Discovery
from ..peer_handle import PeerHandle
from .grpc_peer_handle import GRPCPeerHandle
  10. class GRPCDiscovery(Discovery):
  11. def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities=None):
  12. self.node_id = node_id
  13. self.node_port = node_port
  14. self.device_capabilities = device_capabilities
  15. self.listen_port = listen_port
  16. self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
  17. self.broadcast_interval = broadcast_interval
  18. self.known_peers: Dict[str, GRPCPeerHandle] = {}
  19. self.peer_last_seen: Dict[str, float] = {}
  20. self.broadcast_task = None
  21. self.listen_task = None
  22. self.cleanup_task = None
  23. async def start(self):
  24. self.broadcast_task = asyncio.create_task(self._broadcast_presence())
  25. self.listen_task = asyncio.create_task(self._listen_for_peers())
  26. self.cleanup_task = asyncio.create_task(self._cleanup_peers())
  27. async def stop(self):
  28. if self.broadcast_task:
  29. self.broadcast_task.cancel()
  30. if self.listen_task:
  31. self.listen_task.cancel()
  32. if self.cleanup_task:
  33. self.cleanup_task.cancel()
  34. if self.broadcast_task or self.listen_task or self.cleanup_task:
  35. await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
  36. async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
  37. print("Starting peer discovery process...")
  38. if wait_for_peers > 0:
  39. while not self.known_peers:
  40. print("No peers discovered yet, retrying in 1 second...")
  41. await asyncio.sleep(1) # Keep trying to find peers
  42. print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
  43. grace_period = 5 # seconds
  44. while True:
  45. initial_peer_count = len(self.known_peers)
  46. print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
  47. await asyncio.sleep(grace_period)
  48. if len(self.known_peers) == initial_peer_count:
  49. if wait_for_peers > 0:
  50. print(f"Waiting additional {wait_for_peers} seconds for more peers.")
  51. await asyncio.sleep(wait_for_peers)
  52. wait_for_peers = 0
  53. else:
  54. print("No new peers discovered in the last grace period. Ending discovery process.")
  55. break # No new peers found in the grace period, we are done
  56. return list(self.known_peers.values())
  57. async def _broadcast_presence(self):
  58. if not self.device_capabilities:
  59. self.device_capabilities = device_capabilities()
  60. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
  61. sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
  62. sock.settimeout(0.5)
  63. message = json.dumps({
  64. "type": "discovery",
  65. "node_id": self.node_id,
  66. "grpc_port": self.node_port,
  67. "device_capabilities": {
  68. "model": self.device_capabilities.model,
  69. "chip": self.device_capabilities.chip,
  70. "memory": self.device_capabilities.memory
  71. }
  72. }).encode('utf-8')
  73. while True:
  74. sock.sendto(message, ('<broadcast>', self.broadcast_port))
  75. await asyncio.sleep(self.broadcast_interval)
  76. async def _listen_for_peers(self):
  77. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  78. sock.bind(('', self.listen_port))
  79. sock.setblocking(False)
  80. while True:
  81. try:
  82. data, addr = await asyncio.get_event_loop().sock_recvfrom(sock, 1024)
  83. message = json.loads(data.decode('utf-8'))
  84. print(f"received from peer {addr}: {message}")
  85. if message['type'] == 'discovery' and message['node_id'] != self.node_id:
  86. peer_id = message['node_id']
  87. peer_host = addr[0]
  88. peer_port = message['grpc_port']
  89. device_capabilities = DeviceCapabilities(**message['device_capabilities'])
  90. if peer_id not in self.known_peers:
  91. self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
  92. self.peer_last_seen[peer_id] = time.time()
  93. except Exception as e:
  94. print(f"Error in peer discovery: {e}")
  95. await asyncio.sleep(self.broadcast_interval / 2)
  96. async def _cleanup_peers(self):
  97. while True:
  98. current_time = time.time()
  99. timeout = 15 * self.broadcast_interval
  100. peers_to_remove = [peer_id for peer_id, last_seen in self.peer_last_seen.items() if current_time - last_seen > timeout]
  101. for peer_id in peers_to_remove:
  102. del self.known_peers[peer_id]
  103. del self.peer_last_seen[peer_id]
  104. print(f"Removed peer {peer_id} due to inactivity.")
  105. await asyncio.sleep(self.broadcast_interval)