Browse Source

implement a health check for peers and discovery should only return healthy peers

Alex Cheema 10 months ago
parent
commit
de2f6d2e6e

+ 11 - 0
exo/networking/grpc/grpc_peer_handle.py

@@ -1,5 +1,6 @@
 import grpc
 import numpy as np
+import asyncio
 from typing import Optional, Tuple, List
 
 # These would be generated from the .proto file
@@ -43,6 +44,16 @@ class GRPCPeerHandle(PeerHandle):
     self.channel = None
     self.stub = None
 
+  async def health_check(self) -> bool:
+    try:
+      request = node_service_pb2.HealthCheckRequest()
+      response = await asyncio.wait_for(self.stub.HealthCheck(request), timeout=5)
+      return response.is_healthy
+    except asyncio.TimeoutError:
+      return False
+    except:
+      return False
+
   async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,

+ 7 - 0
exo/networking/grpc/node_service.proto

@@ -9,6 +9,7 @@ service NodeService {
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
   rpc SendResult (SendResultRequest) returns (Empty) {}
   rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
+  rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
 }
 
 message Shard {
@@ -86,4 +87,10 @@ message SendOpaqueStatusRequest {
   string status = 2;
 }
 
+message HealthCheckRequest {}
+
+message HealthCheckResponse {
+  bool is_healthy = 1;
+}
+
 message Empty {}

File diff suppressed because it is too large
+ 0 - 1
exo/networking/grpc/node_service_pb2.py


+ 316 - 228
exo/networking/grpc/node_service_pb2_grpc.py

@@ -12,261 +12,349 @@ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
 _version_not_supported = False
 
 try:
-  from grpc._utilities import first_version_is_lower
-  _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+    from grpc._utilities import first_version_is_lower
+    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
 except ImportError:
-  _version_not_supported = True
+    _version_not_supported = True
 
 if _version_not_supported:
-  warnings.warn(
-    f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' +
-    f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' +
-    f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning
-  )
+    warnings.warn(
+        f'The grpc package installed is at version {GRPC_VERSION},'
+        + f' but the generated code in node_service_pb2_grpc.py depends on'
+        + f' grpcio>={GRPC_GENERATED_VERSION}.'
+        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
+        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
+        RuntimeWarning
+    )
 
 
 class NodeServiceStub(object):
-  """Missing associated documentation comment in .proto file."""
-  def __init__(self, channel):
-    """Constructor.
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
 
         Args:
             channel: A grpc.Channel.
         """
-    self.SendPrompt = channel.unary_unary(
-      '/node_service.NodeService/SendPrompt',
-      request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Tensor.FromString,
-      _registered_method=True
-    )
-    self.SendTensor = channel.unary_unary(
-      '/node_service.NodeService/SendTensor',
-      request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Tensor.FromString,
-      _registered_method=True
-    )
-    self.GetInferenceResult = channel.unary_unary(
-      '/node_service.NodeService/GetInferenceResult',
-      request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-      response_deserializer=node__service__pb2.InferenceResult.FromString,
-      _registered_method=True
-    )
-    self.CollectTopology = channel.unary_unary(
-      '/node_service.NodeService/CollectTopology',
-      request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Topology.FromString,
-      _registered_method=True
-    )
-    self.SendResult = channel.unary_unary(
-      '/node_service.NodeService/SendResult',
-      request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Empty.FromString,
-      _registered_method=True
-    )
-    self.SendOpaqueStatus = channel.unary_unary(
-      '/node_service.NodeService/SendOpaqueStatus',
-      request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Empty.FromString,
-      _registered_method=True
-    )
+        self.SendPrompt = channel.unary_unary(
+                '/node_service.NodeService/SendPrompt',
+                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
+                _registered_method=True)
+        self.SendTensor = channel.unary_unary(
+                '/node_service.NodeService/SendTensor',
+                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
+                _registered_method=True)
+        self.GetInferenceResult = channel.unary_unary(
+                '/node_service.NodeService/GetInferenceResult',
+                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.InferenceResult.FromString,
+                _registered_method=True)
+        self.CollectTopology = channel.unary_unary(
+                '/node_service.NodeService/CollectTopology',
+                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Topology.FromString,
+                _registered_method=True)
+        self.SendResult = channel.unary_unary(
+                '/node_service.NodeService/SendResult',
+                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.SendOpaqueStatus = channel.unary_unary(
+                '/node_service.NodeService/SendOpaqueStatus',
+                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.HealthCheck = channel.unary_unary(
+                '/node_service.NodeService/HealthCheck',
+                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
+                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
+                _registered_method=True)
 
 
 class NodeServiceServicer(object):
-  """Missing associated documentation comment in .proto file."""
-  def SendPrompt(self, request, context):
     """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
 
-  def SendTensor(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendPrompt(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def GetInferenceResult(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendTensor(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def CollectTopology(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def GetInferenceResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def SendResult(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def CollectTopology(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def SendOpaqueStatus(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def SendOpaqueStatus(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def HealthCheck(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
 
 def add_NodeServiceServicer_to_server(servicer, server):
-  rpc_method_handlers = {
-    'SendPrompt':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendPrompt,
-        request_deserializer=node__service__pb2.PromptRequest.FromString,
-        response_serializer=node__service__pb2.Tensor.SerializeToString,
-      ),
-    'SendTensor':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendTensor,
-        request_deserializer=node__service__pb2.TensorRequest.FromString,
-        response_serializer=node__service__pb2.Tensor.SerializeToString,
-      ),
-    'GetInferenceResult':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.GetInferenceResult,
-        request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-        response_serializer=node__service__pb2.InferenceResult.SerializeToString,
-      ),
-    'CollectTopology':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.CollectTopology,
-        request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-        response_serializer=node__service__pb2.Topology.SerializeToString,
-      ),
-    'SendResult':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendResult,
-        request_deserializer=node__service__pb2.SendResultRequest.FromString,
-        response_serializer=node__service__pb2.Empty.SerializeToString,
-      ),
-    'SendOpaqueStatus':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendOpaqueStatus,
-        request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
-        response_serializer=node__service__pb2.Empty.SerializeToString,
-      ),
-  }
-  generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
-  server.add_generic_rpc_handlers((generic_handler,))
-  server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
+    rpc_method_handlers = {
+            'SendPrompt': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendPrompt,
+                    request_deserializer=node__service__pb2.PromptRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+            ),
+            'SendTensor': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendTensor,
+                    request_deserializer=node__service__pb2.TensorRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+            ),
+            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.GetInferenceResult,
+                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+            ),
+            'CollectTopology': grpc.unary_unary_rpc_method_handler(
+                    servicer.CollectTopology,
+                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+                    response_serializer=node__service__pb2.Topology.SerializeToString,
+            ),
+            'SendResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendResult,
+                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendOpaqueStatus,
+                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'HealthCheck': grpc.unary_unary_rpc_method_handler(
+                    servicer.HealthCheck,
+                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
+                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
+            ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+            'node_service.NodeService', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
 
 
-# This class is part of an EXPERIMENTAL API.
+ # This class is part of an EXPERIMENTAL API.
 class NodeService(object):
-  """Missing associated documentation comment in .proto file."""
-  @staticmethod
-  def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendPrompt',
-      node__service__pb2.PromptRequest.SerializeToString,
-      node__service__pb2.Tensor.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    """Missing associated documentation comment in .proto file."""
 
-  @staticmethod
-  def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendTensor',
-      node__service__pb2.TensorRequest.SerializeToString,
-      node__service__pb2.Tensor.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendPrompt(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendPrompt',
+            node__service__pb2.PromptRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/GetInferenceResult',
-      node__service__pb2.GetInferenceResultRequest.SerializeToString,
-      node__service__pb2.InferenceResult.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendTensor(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendTensor',
+            node__service__pb2.TensorRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/CollectTopology',
-      node__service__pb2.CollectTopologyRequest.SerializeToString,
-      node__service__pb2.Topology.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def GetInferenceResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/GetInferenceResult',
+            node__service__pb2.GetInferenceResultRequest.SerializeToString,
+            node__service__pb2.InferenceResult.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendResult',
-      node__service__pb2.SendResultRequest.SerializeToString,
-      node__service__pb2.Empty.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def CollectTopology(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/CollectTopology',
+            node__service__pb2.CollectTopologyRequest.SerializeToString,
+            node__service__pb2.Topology.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendOpaqueStatus',
-      node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-      node__service__pb2.Empty.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendResult',
+            node__service__pb2.SendResultRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def SendOpaqueStatus(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendOpaqueStatus',
+            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def HealthCheck(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/HealthCheck',
+            node__service__pb2.HealthCheckRequest.SerializeToString,
+            node__service__pb2.HealthCheckResponse.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)

+ 4 - 0
exo/networking/peer_handle.py

@@ -30,6 +30,10 @@ class PeerHandle(ABC):
   async def disconnect(self) -> None:
     pass
 
+  @abstractmethod
+  async def health_check(self) -> bool:
+    pass
+
   @abstractmethod
   async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     pass

+ 25 - 15
exo/networking/udp_discovery.py

@@ -139,15 +139,21 @@ class UDPDiscovery(Discovery):
       peer_host = addr[0]
       peer_port = message["grpc_port"]
       device_capabilities = DeviceCapabilities(**message["device_capabilities"])
-      if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
-        if DEBUG >= 1: print(
-          f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
-        self.known_peers[peer_id] = (
-          self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
-          time.time(),
-          time.time(),
-        )
-      self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
+
+      # Create a new peer handle
+      new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+
+      # Check if the new peer is healthy before adding
+      if await new_peer_handle.health_check():
+        if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
+          if DEBUG >= 1: print(
+            f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
+          self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time())
+        else:
+          # Update last seen time for existing peer
+          self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
+      else:
+        if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} failed health check. Not adding.")
 
   async def task_listen_for_peers(self):
     await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
@@ -158,14 +164,18 @@ class UDPDiscovery(Discovery):
     while True:
       try:
         current_time = time.time()
-        peers_to_remove = [
-          peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
-          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
-        ]
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
+        peers_to_remove = []
+        for peer_id, (peer_handle, connected_at, last_seen) in self.known_peers.items():
+          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or \
+             current_time - last_seen > self.discovery_timeout or \
+             not await peer_handle.health_check():
+            peers_to_remove.append(peer_id)
+
+        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
+
         for peer_id in peers_to_remove:
           if peer_id in self.known_peers: del self.known_peers[peer_id]
-          if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
+          if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity or failed health check.")
       except Exception as e:
         print(f"Error in cleanup peers: {e}")
         print(traceback.format_exc())

Some files were not shown because too many files changed in this diff