浏览代码

make all I/O ops in manual_discovery.py run inside a ThreadPoolExecutor

Ian Paul 5 月之前
父节点
当前提交
b066c944f3
共有 1 个文件被更改,包括 40 次插入和 20 次删除
  1. 40 20
      exo/networking/manual/manual_discovery.py

+ 40 - 20
exo/networking/manual/manual_discovery.py

@@ -1,8 +1,9 @@
 import os
 import asyncio
-from exo.networking.discovery import Discovery
 from typing import Dict, List, Callable, Optional
+from concurrent.futures import ThreadPoolExecutor
 
+from exo.networking.discovery import Discovery
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
 from exo.helpers import DEBUG_DISCOVERY
@@ -26,6 +27,8 @@ class ManualDiscovery(Discovery):
 
     self._cached_peers: Dict[str, PeerConfig] = {}
     self._last_modified_time: Optional[float] = None
+    self._file_executor = ThreadPoolExecutor(max_workers=1)
+    self._lock = asyncio.Lock()
 
   async def start(self) -> None:
     self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
@@ -34,6 +37,7 @@ class ManualDiscovery(Discovery):
   async def stop(self) -> None:
     if self.listen_task: self.listen_task.cancel()
     if self.cleanup_task: self.cleanup_task.cancel()
+    self._file_executor.shutdown(wait=True)
 
   async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
     if wait_for_peers > 0:
@@ -46,7 +50,8 @@ class ManualDiscovery(Discovery):
   async def task_find_peers_from_config(self):
     if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
     while True:
-      for peer_id, peer_config in self._get_peers().items():
+      peers_from_config = await self._get_peers()
+      for peer_id, peer_config in peers_from_config.items():
         try:
           if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
           peer = self.known_peers.get(peer_id)
@@ -72,7 +77,7 @@ class ManualDiscovery(Discovery):
   async def task_clean_up_peers_from_config(self):
     if DEBUG_DISCOVERY >= 2: print("Starting task to clean up peers from config...")
     while True:
-      peers_from_config = self._get_peers()
+      peers_from_config = await self._get_peers()
       if peers_from_config:
         peers_to_remove = [peer for peer in self.known_peers.keys() if peer not in peers_from_config]
 
@@ -83,30 +88,45 @@ class ManualDiscovery(Discovery):
 
       await asyncio.sleep(1.0)
 
-  def _get_peers(self):
+  async def _get_peers(self):
     try:
-      current_mtime = os.path.getmtime(self.network_config_path)
-
-      if self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time:
-        return self._cached_peers
+      async with self._lock:
+        loop = asyncio.get_running_loop()
+        current_mtime = await loop.run_in_executor(
+          self._file_executor,
+          os.path.getmtime,
+          self.network_config_path
+        )
 
-      topology = NetworkTopology.from_path(self.network_config_path)
+        if (self._cached_peers is not None and
+          self._last_modified_time is not None and
+          current_mtime <= self._last_modified_time):
+          return self._cached_peers
 
-      if self.node_id not in topology.peers:
-        raise ValueError(
-          f"Node ID {self.node_id} not found in network config file "
-          f"{self.network_config_path}. Please run with `node_id` set to "
-          f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
+        topology = await loop.run_in_executor(
+          self._file_executor,
+          NetworkTopology.from_path,
+          self.network_config_path
         )
 
-      peers_in_network: Dict[str, PeerConfig] = topology.peers
-      peers_in_network.pop(self.node_id)
+        if self.node_id not in topology.peers:
+          raise ValueError(
+            f"Node ID {self.node_id} not found in network config file "
+            f"{self.network_config_path}. Please run with `node_id` set to "
+            f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
+          )
+
+        peers_in_network: Dict[str, PeerConfig] = topology.peers
+        peers_in_network.pop(self.node_id)
 
-      self._cached_peers = peers_in_network
-      self._last_modified_time = current_mtime
+        self._cached_peers = peers_in_network
+        self._last_modified_time = current_mtime
 
-      return peers_in_network
+        return peers_in_network
 
     except Exception as e:
-      if DEBUG_DISCOVERY >= 2: print(f"Error when loading network config file from {self.network_config_path}. Please update the config file in order to successfully discover peers. Exception: {e}")
+      if DEBUG_DISCOVERY >= 2:
+        print(f"Error when loading network config file from {self.network_config_path}. "
+          f"Please update the config file in order to successfully discover peers. "
+          f"Exception: {e}")
       return self._cached_peers