Browse Source

ensure connected when health checking

Alex Cheema 9 months ago
parent
commit
5207edbd94
1 changed files with 13 additions and 2 deletions
  1. 13 2
      exo/networking/grpc/grpc_peer_handle.py

+ 13 - 2
exo/networking/grpc/grpc_peer_handle.py

@@ -10,6 +10,7 @@ from ..peer_handle import PeerHandle
 from exo.inference.shard import Shard
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities
+from exo.helpers import DEBUG
 
 
 class GRPCPeerHandle(PeerHandle):
@@ -30,8 +31,9 @@ class GRPCPeerHandle(PeerHandle):
     return self._device_capabilities
 
   async def connect(self):
-    self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
-    self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
+    if self.channel is None:
+      self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
+      self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
     await self.channel.channel_ready()
 
   async def is_connected(self) -> bool:
@@ -43,14 +45,23 @@ class GRPCPeerHandle(PeerHandle):
     self.channel = None
     self.stub = None
 
+  async def _ensure_connected(self):
+    if not await self.is_connected():
+      await self.connect()
+
   async def health_check(self) -> bool:
     try:
+      await self._ensure_connected()
       request = node_service_pb2.HealthCheckRequest()
       response = await asyncio.wait_for(self.stub.HealthCheck(request), timeout=5)
       return response.is_healthy
     except asyncio.TimeoutError:
       return False
     except:
+      if DEBUG >= 4:
+        print(f"Health check failed for {self._id}@{self.address}.")
+        import traceback
+        traceback.print_exc()
       return False
 
   async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]: