Merge pull request #232 from exo-explore/health_checks

Health checks
Alex Cheema, 10 months ago
parent
commit
c379cfb204

+ 1 - 1
.circleci/config.yml

@@ -144,7 +144,7 @@ jobs:
             PID2=$!
             sleep 10
             kill $PID1 $PID2
-            if grep -q "Successfully connected peers: \['node2@.*:.*'\]" output1.log && ! grep -q "Failed to connect peers:" output1.log && grep -q "Successfully connected peers: \['node1@.*:.*'\]" output2.log && ! grep -q "Failed to connect peers:" output2.log; then
+            if grep -q "Peer statuses: {\\'node2\\': \\'is_connected=True, health_check=True" output1.log && ! grep -q "Failed to connect peers:" output1.log && grep -q "Peer statuses: {\\'node1\\': \\'is_connected=True, health_check=True" output2.log && ! grep -q "Failed to connect peers:" output2.log; then
               echo "Test passed: Both instances discovered each other"
               exit 0
             else

+ 23 - 3
exo/networking/grpc/grpc_peer_handle.py

@@ -1,8 +1,8 @@
 import grpc
 import numpy as np
+import asyncio
 from typing import Optional, Tuple, List
 
-# These would be generated from the .proto file
 from . import node_service_pb2
 from . import node_service_pb2_grpc
 
@@ -10,6 +10,7 @@ from ..peer_handle import PeerHandle
 from exo.inference.shard import Shard
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities
+from exo.helpers import DEBUG
 
 
 class GRPCPeerHandle(PeerHandle):
@@ -30,8 +31,9 @@ class GRPCPeerHandle(PeerHandle):
     return self._device_capabilities
 
   async def connect(self):
-    self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
-    self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
+    if self.channel is None:
+      self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
+      self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
     await self.channel.channel_ready()
 
   async def is_connected(self) -> bool:
@@ -43,6 +45,24 @@ class GRPCPeerHandle(PeerHandle):
     self.channel = None
     self.stub = None
 
+  async def _ensure_connected(self):
+    if not await self.is_connected(): await asyncio.wait_for(self.connect(), timeout=5)
+
+  async def health_check(self) -> bool:
+    try:
+      await self._ensure_connected()
+      request = node_service_pb2.HealthCheckRequest()
+      response = await asyncio.wait_for(self.stub.HealthCheck(request), timeout=5)
+      return response.is_healthy
+    except asyncio.TimeoutError:
+      return False
+    except:
+      if DEBUG >= 4:
+        print(f"Health check failed for {self._id}@{self.address}.")
+        import traceback
+        traceback.print_exc()
+      return False
+
   async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,

+ 3 - 0
exo/networking/grpc/grpc_server.py

@@ -116,3 +116,6 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     if DEBUG >= 8: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
     self.node.on_opaque_status.trigger_all(request_id, status)
     return node_service_pb2.Empty()
+
+  async def HealthCheck(self, request, context):
+    return node_service_pb2.HealthCheckResponse(is_healthy=True)
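
For reference, a client can hit this new handler directly with the regenerated stub (shown further below). A minimal sketch, assuming a node is already serving gRPC at the given address (the address is a placeholder, not taken from the PR):

import asyncio
import grpc

from exo.networking.grpc import node_service_pb2, node_service_pb2_grpc

async def check_node(address: str = "localhost:8080") -> bool:
  # Open a channel to the node's GRPCServer and call the new HealthCheck RPC.
  async with grpc.aio.insecure_channel(address) as channel:
    stub = node_service_pb2_grpc.NodeServiceStub(channel)
    response = await stub.HealthCheck(node_service_pb2.HealthCheckRequest())
    return response.is_healthy

# e.g. print(asyncio.run(check_node()))  # True if the node is up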

+ 7 - 0
exo/networking/grpc/node_service.proto

@@ -9,6 +9,7 @@ service NodeService {
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
   rpc SendResult (SendResultRequest) returns (Empty) {}
   rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
+  rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
 }
 
 message Shard {
@@ -86,4 +87,10 @@ message SendOpaqueStatusRequest {
   string status = 2;
 }
 
+message HealthCheckRequest {}
+
+message HealthCheckResponse {
+  bool is_healthy = 1;
+}
+
 message Empty {}
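
The regenerated node_service_pb2.py and node_service_pb2_grpc.py diffs below are the output of re-running the protobuf/gRPC code generator over this file. One plausible invocation with grpcio-tools (the exact paths are an assumption, not taken from the PR):

python -m grpc_tools.protoc \
  -I exo/networking/grpc \
  --python_out=exo/networking/grpc \
  --grpc_python_out=exo/networking/grpc \
  exo/networking/grpc/node_service.proto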

The diff for this file was suppressed because it is too large
+ 0 - 1
exo/networking/grpc/node_service_pb2.py


+ 316 - 228
exo/networking/grpc/node_service_pb2_grpc.py

@@ -12,261 +12,349 @@ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
 _version_not_supported = False
 
 try:
-  from grpc._utilities import first_version_is_lower
-  _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+    from grpc._utilities import first_version_is_lower
+    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
 except ImportError:
-  _version_not_supported = True
+    _version_not_supported = True
 
 if _version_not_supported:
-  warnings.warn(
-    f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' +
-    f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' +
-    f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning
-  )
+    warnings.warn(
+        f'The grpc package installed is at version {GRPC_VERSION},'
+        + f' but the generated code in node_service_pb2_grpc.py depends on'
+        + f' grpcio>={GRPC_GENERATED_VERSION}.'
+        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
+        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
+        RuntimeWarning
+    )
 
 
 class NodeServiceStub(object):
-  """Missing associated documentation comment in .proto file."""
-  def __init__(self, channel):
-    """Constructor.
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
 
         Args:
             channel: A grpc.Channel.
         """
-    self.SendPrompt = channel.unary_unary(
-      '/node_service.NodeService/SendPrompt',
-      request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Tensor.FromString,
-      _registered_method=True
-    )
-    self.SendTensor = channel.unary_unary(
-      '/node_service.NodeService/SendTensor',
-      request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Tensor.FromString,
-      _registered_method=True
-    )
-    self.GetInferenceResult = channel.unary_unary(
-      '/node_service.NodeService/GetInferenceResult',
-      request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-      response_deserializer=node__service__pb2.InferenceResult.FromString,
-      _registered_method=True
-    )
-    self.CollectTopology = channel.unary_unary(
-      '/node_service.NodeService/CollectTopology',
-      request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Topology.FromString,
-      _registered_method=True
-    )
-    self.SendResult = channel.unary_unary(
-      '/node_service.NodeService/SendResult',
-      request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Empty.FromString,
-      _registered_method=True
-    )
-    self.SendOpaqueStatus = channel.unary_unary(
-      '/node_service.NodeService/SendOpaqueStatus',
-      request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Empty.FromString,
-      _registered_method=True
-    )
+        self.SendPrompt = channel.unary_unary(
+                '/node_service.NodeService/SendPrompt',
+                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
+                _registered_method=True)
+        self.SendTensor = channel.unary_unary(
+                '/node_service.NodeService/SendTensor',
+                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
+                _registered_method=True)
+        self.GetInferenceResult = channel.unary_unary(
+                '/node_service.NodeService/GetInferenceResult',
+                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.InferenceResult.FromString,
+                _registered_method=True)
+        self.CollectTopology = channel.unary_unary(
+                '/node_service.NodeService/CollectTopology',
+                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Topology.FromString,
+                _registered_method=True)
+        self.SendResult = channel.unary_unary(
+                '/node_service.NodeService/SendResult',
+                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.SendOpaqueStatus = channel.unary_unary(
+                '/node_service.NodeService/SendOpaqueStatus',
+                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.HealthCheck = channel.unary_unary(
+                '/node_service.NodeService/HealthCheck',
+                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
+                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
+                _registered_method=True)
 
 
 class NodeServiceServicer(object):
-  """Missing associated documentation comment in .proto file."""
-  def SendPrompt(self, request, context):
     """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
 
-  def SendTensor(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendPrompt(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def GetInferenceResult(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendTensor(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def CollectTopology(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def GetInferenceResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def SendResult(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def CollectTopology(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def SendOpaqueStatus(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def SendOpaqueStatus(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def HealthCheck(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
 
 def add_NodeServiceServicer_to_server(servicer, server):
-  rpc_method_handlers = {
-    'SendPrompt':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendPrompt,
-        request_deserializer=node__service__pb2.PromptRequest.FromString,
-        response_serializer=node__service__pb2.Tensor.SerializeToString,
-      ),
-    'SendTensor':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendTensor,
-        request_deserializer=node__service__pb2.TensorRequest.FromString,
-        response_serializer=node__service__pb2.Tensor.SerializeToString,
-      ),
-    'GetInferenceResult':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.GetInferenceResult,
-        request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-        response_serializer=node__service__pb2.InferenceResult.SerializeToString,
-      ),
-    'CollectTopology':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.CollectTopology,
-        request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-        response_serializer=node__service__pb2.Topology.SerializeToString,
-      ),
-    'SendResult':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendResult,
-        request_deserializer=node__service__pb2.SendResultRequest.FromString,
-        response_serializer=node__service__pb2.Empty.SerializeToString,
-      ),
-    'SendOpaqueStatus':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendOpaqueStatus,
-        request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
-        response_serializer=node__service__pb2.Empty.SerializeToString,
-      ),
-  }
-  generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
-  server.add_generic_rpc_handlers((generic_handler,))
-  server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
+    rpc_method_handlers = {
+            'SendPrompt': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendPrompt,
+                    request_deserializer=node__service__pb2.PromptRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+            ),
+            'SendTensor': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendTensor,
+                    request_deserializer=node__service__pb2.TensorRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+            ),
+            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.GetInferenceResult,
+                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+            ),
+            'CollectTopology': grpc.unary_unary_rpc_method_handler(
+                    servicer.CollectTopology,
+                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+                    response_serializer=node__service__pb2.Topology.SerializeToString,
+            ),
+            'SendResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendResult,
+                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendOpaqueStatus,
+                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'HealthCheck': grpc.unary_unary_rpc_method_handler(
+                    servicer.HealthCheck,
+                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
+                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
+            ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+            'node_service.NodeService', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
 
 
-# This class is part of an EXPERIMENTAL API.
+ # This class is part of an EXPERIMENTAL API.
 class NodeService(object):
-  """Missing associated documentation comment in .proto file."""
-  @staticmethod
-  def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendPrompt',
-      node__service__pb2.PromptRequest.SerializeToString,
-      node__service__pb2.Tensor.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    """Missing associated documentation comment in .proto file."""
 
-  @staticmethod
-  def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendTensor',
-      node__service__pb2.TensorRequest.SerializeToString,
-      node__service__pb2.Tensor.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendPrompt(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendPrompt',
+            node__service__pb2.PromptRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/GetInferenceResult',
-      node__service__pb2.GetInferenceResultRequest.SerializeToString,
-      node__service__pb2.InferenceResult.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendTensor(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendTensor',
+            node__service__pb2.TensorRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/CollectTopology',
-      node__service__pb2.CollectTopologyRequest.SerializeToString,
-      node__service__pb2.Topology.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def GetInferenceResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/GetInferenceResult',
+            node__service__pb2.GetInferenceResultRequest.SerializeToString,
+            node__service__pb2.InferenceResult.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendResult',
-      node__service__pb2.SendResultRequest.SerializeToString,
-      node__service__pb2.Empty.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def CollectTopology(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/CollectTopology',
+            node__service__pb2.CollectTopologyRequest.SerializeToString,
+            node__service__pb2.Topology.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendOpaqueStatus',
-      node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-      node__service__pb2.Empty.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendResult',
+            node__service__pb2.SendResultRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def SendOpaqueStatus(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendOpaqueStatus',
+            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def HealthCheck(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/HealthCheck',
+            node__service__pb2.HealthCheckRequest.SerializeToString,
+            node__service__pb2.HealthCheckResponse.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)

+ 4 - 0
exo/networking/peer_handle.py

@@ -30,6 +30,10 @@ class PeerHandle(ABC):
   async def disconnect(self) -> None:
     pass
 
+  @abstractmethod
+  async def health_check(self) -> bool:
+    pass
+
   @abstractmethod
   async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     pass
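
Because health_check is now abstract, every PeerHandle implementation must provide it. A minimal illustrative sketch of how a non-gRPC (e.g. in-process test) handle might satisfy the contract; this class is hypothetical, not part of the PR, and its other abstract methods are elided:

from exo.networking.peer_handle import PeerHandle

class InProcessPeerHandle(PeerHandle):
  # Illustrative only: treat "connected" as "healthy" for a local peer.
  async def health_check(self) -> bool:
    return await self.is_connected()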

+ 12 - 3
exo/networking/tailscale_discovery.py

@@ -87,13 +87,22 @@ class TailscaleDiscovery(Discovery):
             continue
 
           if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
+            new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+            if not await new_peer_handle.health_check():
+              if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
+              continue
+
             if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
             self.known_peers[peer_id] = (
-              self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
+              new_peer_handle,
               current_time,
               current_time,
             )
           else:
+            if not await self.known_peers[peer_id][0].health_check():
+              if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Removing.")
+              if peer_id in self.known_peers: del self.known_peers[peer_id]
+              continue
             self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], current_time)
 
       except Exception as e:
@@ -126,9 +135,9 @@ class TailscaleDiscovery(Discovery):
         current_time = time.time()
         peers_to_remove = [
           peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
-          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
+          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout or not await peer_handle.health_check()
         ]
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
+        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}, health_check={await peer_handle.health_check()}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
         for peer_id in peers_to_remove:
           if peer_id in self.known_peers: del self.known_peers[peer_id]
           if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")

+ 23 - 14
exo/networking/udp_discovery.py

@@ -139,15 +139,20 @@ class UDPDiscovery(Discovery):
       peer_host = addr[0]
       peer_port = message["grpc_port"]
       device_capabilities = DeviceCapabilities(**message["device_capabilities"])
+
       if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
-        if DEBUG >= 1: print(
-          f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
-        self.known_peers[peer_id] = (
-          self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
-          time.time(),
-          time.time(),
-        )
-      self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
+        new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+        if not await new_peer_handle.health_check():
+          if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
+          return
+        if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
+        self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time())
+      else:
+        if not await self.known_peers[peer_id][0].health_check():
+          if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Removing.")
+          if peer_id in self.known_peers: del self.known_peers[peer_id]
+          return
+        self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
 
   async def task_listen_for_peers(self):
     await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
@@ -158,14 +163,18 @@ class UDPDiscovery(Discovery):
     while True:
       try:
         current_time = time.time()
-        peers_to_remove = [
-          peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
-          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
-        ]
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
+        peers_to_remove = []
+        for peer_id, (peer_handle, connected_at, last_seen) in self.known_peers.items():
+          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or \
+             (current_time - last_seen > self.discovery_timeout) or \
+             (not await peer_handle.health_check()):
+            peers_to_remove.append(peer_id)
+
+        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
+
         for peer_id in peers_to_remove:
           if peer_id in self.known_peers: del self.known_peers[peer_id]
-          if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
+          if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity or failed health check.")
       except Exception as e:
         print(f"Error in cleanup peers: {e}")
         print(traceback.format_exc())

+ 1 - 1
exo/orchestration/standard_node.py

@@ -336,7 +336,7 @@ class StandardNode(Node):
       if failed_connects: print(f"Failed to connect peers: {_pretty(failed_connects)}")
 
     self.peers = next_peers
-    return len(peers_to_connect) > 0 or len(peers_to_disconnect) > 0
+    return len(peers_added) > 0 or len(peers_removed) > 0 or len(peers_updated) > 0
 
   async def periodic_topology_collection(self, interval: int):
     while True:

Some files were not shown because too many files changed in this diff