manual_discovery.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import os
  2. import asyncio
  3. from typing import Dict, List, Callable, Optional
  4. from concurrent.futures import ThreadPoolExecutor
  5. from exo.networking.discovery import Discovery
  6. from exo.topology.device_capabilities import DeviceCapabilities
  7. from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
  8. from exo.helpers import DEBUG_DISCOVERY
  9. from exo.networking.peer_handle import PeerHandle
  10. class ManualDiscovery(Discovery):
  11. def __init__(
  12. self,
  13. network_config_path: str,
  14. node_id: str,
  15. create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
  16. ):
  17. self.network_config_path = network_config_path
  18. self.node_id = node_id
  19. self.create_peer_handle = create_peer_handle
  20. self.listen_task = None
  21. self.known_peers: Dict[str, PeerHandle] = {}
  22. self._cached_peers: Dict[str, PeerConfig] = {}
  23. self._last_modified_time: Optional[float] = None
  24. self._file_executor = ThreadPoolExecutor(max_workers=1)
  25. async def start(self) -> None:
  26. self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
  27. async def stop(self) -> None:
  28. if self.listen_task: self.listen_task.cancel()
  29. self._file_executor.shutdown(wait=True)
  30. async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
  31. if wait_for_peers > 0:
  32. while len(self.known_peers) < wait_for_peers:
  33. if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
  34. await asyncio.sleep(0.1)
  35. if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}")
  36. return list(self.known_peers.values())
  37. async def task_find_peers_from_config(self):
  38. if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
  39. while True:
  40. peers_from_config = await self._get_peers()
  41. new_known_peers = {}
  42. for peer_id, peer_config in peers_from_config.items():
  43. try:
  44. if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
  45. peer = self.known_peers.get(peer_id)
  46. if not peer:
  47. if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
  48. peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities)
  49. is_healthy = await peer.health_check()
  50. if is_healthy:
  51. if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
  52. new_known_peers[peer_id] = peer
  53. elif DEBUG_DISCOVERY >= 2:
  54. print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
  55. except Exception as e:
  56. if DEBUG_DISCOVERY >= 2: print(f"Exception occurred when attempting to add {peer_id=}: {e}")
  57. await asyncio.sleep(5.0)
  58. if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
  59. async def _get_peers(self):
  60. try:
  61. loop = asyncio.get_running_loop()
  62. current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)
  63. if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
  64. return self._cached_peers
  65. topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)
  66. if self.node_id not in topology.peers:
  67. raise ValueError(
  68. f"Node ID {self.node_id} not found in network config file "
  69. f"{self.network_config_path}. Please run with `node_id` set to "
  70. f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
  71. )
  72. peers_in_network = topology.peers
  73. peers_in_network.pop(self.node_id)
  74. self._cached_peers = peers_in_network
  75. self._last_modified_time = current_mtime
  76. return peers_in_network
  77. except Exception as e:
  78. if DEBUG_DISCOVERY >= 2:
  79. print(f"Error when loading network config file from {self.network_config_path}. "
  80. f"Please update the config file in order to successfully discover peers. "
  81. f"Exception: {e}")
  82. return self._cached_peers