Browse Source

Fix ref-counting in DisruptableMockTransport (#92245)

Today `DisruptableMockTransport` leaks refs to transport messages in
various ways if the transport is rebooted. This commit adds the missing
ref-count handling.

Closes #91837
David Turner 2 years ago
parent
commit
f8636c6313

+ 1 - 1
libs/core/src/main/java/org/elasticsearch/core/AbstractRefCounted.java

@@ -47,7 +47,7 @@ public abstract class AbstractRefCounted implements RefCounted {
     public final boolean decRef() {
         touch();
         int i = refCount.decrementAndGet();
-        assert i >= 0;
+        assert i >= 0 : "invalid decRef call: already closed";
         if (i == 0) {
             try {
                 closeInternal();

+ 33 - 0
server/src/main/java/org/elasticsearch/cluster/coordination/CleanableResponseHandler.java

@@ -0,0 +1,33 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.cluster.coordination;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionListenerResponseHandler;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportResponse;
+
+/**
+ * Combines an ActionListenerResponseHandler with an ActionListener.runAfter action, but with an explicit type so that tests that simulate
+ * reboots can release resources without invoking the listener.
+ */
+public class CleanableResponseHandler<T extends TransportResponse> extends ActionListenerResponseHandler<T> {
+    private final Runnable cleanup; // kept as a field so runCleanup() can invoke it without touching the listener
+
+    public CleanableResponseHandler(ActionListener<? super T> listener, Writeable.Reader<T> reader, String executor, Runnable cleanup) {
+        super(ActionListener.runAfter(listener, cleanup), reader, executor); // superclass sees the listener wrapped with the cleanup
+        this.cleanup = cleanup;
+    }
+
+    public void runCleanup() { // runs the cleanup action directly; the listener passed to the constructor is not invoked here
+        assert ThreadPool.assertCurrentThreadPool(); // should only be called from tests which simulate abrupt node restarts
+        cleanup.run();
+    }
+}

+ 4 - 3
server/src/main/java/org/elasticsearch/cluster/coordination/JoinValidationService.java

@@ -310,10 +310,11 @@ public class JoinValidationService {
                 JOIN_VALIDATE_ACTION_NAME,
                 new BytesTransportRequest(bytes, discoveryNode.getVersion()),
                 REQUEST_OPTIONS,
-                new ActionListenerResponseHandler<>(
-                    ActionListener.runAfter(listener, bytes::decRef),
+                new CleanableResponseHandler<>(
+                    listener,
                     in -> TransportResponse.Empty.INSTANCE,
-                    ThreadPool.Names.CLUSTER_COORDINATION
+                    ThreadPool.Names.CLUSTER_COORDINATION,
+                    bytes::decRef
                 )
             );
             if (cachedBytes == null) {

+ 1 - 6
server/src/main/java/org/elasticsearch/cluster/coordination/PublicationTransportHandler.java

@@ -12,7 +12,6 @@ import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.ActionListenerResponseHandler;
 import org.elasticsearch.action.ActionRunnable;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStatePublicationEvent;
@@ -470,11 +469,7 @@ public class PublicationTransportHandler {
                 new BytesTransportRequest(bytes, destination.getVersion()),
                 task,
                 STATE_REQUEST_OPTIONS,
-                new ActionListenerResponseHandler<>(
-                    ActionListener.runAfter(listener, bytes::decRef),
-                    PublishWithJoinResponse::new,
-                    ThreadPool.Names.CLUSTER_COORDINATION
-                )
+                new CleanableResponseHandler<>(listener, PublishWithJoinResponse::new, ThreadPool.Names.CLUSTER_COORDINATION, bytes::decRef)
             );
         }
 

+ 4 - 0
server/src/main/java/org/elasticsearch/transport/TransportService.java

@@ -1387,6 +1387,10 @@ public class TransportService extends AbstractLifecycleComponent
             this.handler = timeoutHandler;
         }
 
+        // for tests: returns the delegate handler held by this wrapper (no access modifier, so package-private)
+        TransportResponseHandler<T> unwrap() {
+            return delegate;
+        }
     }
 
     static class DirectResponseChannel implements TransportChannel {

+ 1 - 2
server/src/test/java/org/elasticsearch/cluster/coordination/CoordinatorTests.java

@@ -1634,8 +1634,7 @@ public class CoordinatorTests extends AbstractCoordinatorTestCase {
         reason = "test includes assertions about JoinHelper logging",
         value = "org.elasticsearch.cluster.coordination.JoinHelper:INFO"
     )
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/91837")
-    public void testCannotJoinClusterWithDifferentUUID() throws IllegalAccessException {
+    public void testCannotJoinClusterWithDifferentUUID() {
         try (Cluster cluster1 = new Cluster(randomIntBetween(1, 3))) {
             cluster1.runRandomly();
             cluster1.stabilise();

+ 2 - 2
server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java

@@ -174,10 +174,10 @@ import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.fetch.FetchPhase;
 import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.test.disruption.DisruptableMockTransport;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.tracing.Tracer;
 import org.elasticsearch.transport.BytesRefRecycler;
+import org.elasticsearch.transport.DisruptableMockTransport;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.junit.After;
@@ -1652,7 +1652,7 @@ public class SnapshotResiliencyTests extends ESTestCase {
                     }
                 );
                 recoverySettings = new RecoverySettings(settings, clusterSettings);
-                mockTransport = new DisruptableMockTransport(node, logger, deterministicTaskQueue) {
+                mockTransport = new DisruptableMockTransport(node, deterministicTaskQueue) {
                     @Override
                     protected ConnectionStatus getConnectionStatus(DiscoveryNode destination) {
                         if (node.equals(destination)) {

+ 5 - 5
test/framework/src/main/java/org/elasticsearch/cluster/coordination/AbstractCoordinatorTestCase.java

@@ -71,11 +71,11 @@ import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
 import org.elasticsearch.monitor.NodeHealthService;
 import org.elasticsearch.monitor.StatusInfo;
 import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.test.disruption.DisruptableMockTransport;
-import org.elasticsearch.test.disruption.DisruptableMockTransport.ConnectionStatus;
 import org.elasticsearch.threadpool.Scheduler;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.BytesRefRecycler;
+import org.elasticsearch.transport.DisruptableMockTransport;
+import org.elasticsearch.transport.DisruptableMockTransport.ConnectionStatus;
 import org.elasticsearch.transport.TransportInterceptor;
 import org.elasticsearch.transport.TransportRequest;
 import org.elasticsearch.transport.TransportRequestOptions;
@@ -1140,7 +1140,7 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
             private void setUp() {
                 final ThreadPool threadPool = deterministicTaskQueue.getThreadPool(this::onNode);
                 clearableRecycler = new ClearableRecycler(recycler);
-                mockTransport = new DisruptableMockTransport(localNode, logger, deterministicTaskQueue) {
+                mockTransport = new DisruptableMockTransport(localNode, deterministicTaskQueue) {
                     @Override
                     protected void execute(Runnable runnable) {
                         deterministicTaskQueue.scheduleNow(onNode(runnable));
@@ -1395,13 +1395,13 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
                     public void run() {
                         if (clusterNodes.contains(ClusterNode.this)) {
                             wrapped.run();
-                        } else if (runnable instanceof DisruptableMockTransport.RebootSensitiveRunnable) {
+                        } else if (runnable instanceof DisruptableMockTransport.RebootSensitiveRunnable rebootSensitiveRunnable) {
                             logger.trace(
                                 "completing reboot-sensitive runnable {} from node {} as node has been removed from cluster",
                                 runnable,
                                 localNode
                             );
-                            ((DisruptableMockTransport.RebootSensitiveRunnable) runnable).ifRebooted();
+                            rebootSensitiveRunnable.ifRebooted();
                         } else {
                             logger.trace("ignoring runnable {} from node {} as node has been removed from cluster", runnable, localNode);
                         }

+ 14 - 4
test/framework/src/main/java/org/elasticsearch/test/transport/MockTransport.java

@@ -89,8 +89,7 @@ public class MockTransport extends StubbableTransport {
      */
     @SuppressWarnings("unchecked")
     public <Response extends TransportResponse> void handleResponse(final long requestId, final Response response) {
-        final TransportResponseHandler<Response> transportResponseHandler = (TransportResponseHandler<Response>) getResponseHandlers()
-            .onResponseReceived(requestId, listener);
+        final TransportResponseHandler<Response> transportResponseHandler = getTransportResponseHandler(requestId);
         if (transportResponseHandler != null) {
             final Response deliveredResponse;
             try (BytesStreamOutput output = new BytesStreamOutput()) {
@@ -100,8 +99,14 @@ public class MockTransport extends StubbableTransport {
                 );
             } catch (IOException | UnsupportedOperationException e) {
                 throw new AssertionError("failed to serialize/deserialize response " + response, e);
+            } finally {
+                response.decRef();
+            }
+            try {
+                transportResponseHandler.handleResponse(deliveredResponse);
+            } finally {
+                deliveredResponse.decRef();
             }
-            transportResponseHandler.handleResponse(deliveredResponse);
         }
     }
 
@@ -154,12 +159,17 @@ public class MockTransport extends StubbableTransport {
      * @param e         the failure
      */
     public void handleError(final long requestId, final TransportException e) {
-        final TransportResponseHandler<?> transportResponseHandler = getResponseHandlers().onResponseReceived(requestId, listener);
+        final TransportResponseHandler<?> transportResponseHandler = getTransportResponseHandler(requestId);
         if (transportResponseHandler != null) {
             transportResponseHandler.handleException(e);
         }
     }
 
+    @SuppressWarnings("unchecked") // the cast below to the caller-chosen T cannot be verified at runtime
+    public <T extends TransportResponse> TransportResponseHandler<T> getTransportResponseHandler(long requestId) {
+        return (TransportResponseHandler<T>) getResponseHandlers().onResponseReceived(requestId, listener); // NOTE(review): presumably also removes the handler for requestId — confirm
+    }
+
     public Connection createConnection(DiscoveryNode node) {
         return new CloseableConnection() {
             @Override

+ 51 - 38
test/framework/src/main/java/org/elasticsearch/test/disruption/DisruptableMockTransport.java → test/framework/src/main/java/org/elasticsearch/transport/DisruptableMockTransport.java

@@ -5,11 +5,13 @@
  * in compliance with, at your election, the Elastic License 2.0 or the Server
  * Side Public License, v 1.
  */
-package org.elasticsearch.test.disruption;
+package org.elasticsearch.transport;
 
+import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.cluster.coordination.CleanableResponseHandler;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
@@ -21,19 +23,6 @@ import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.test.transport.MockTransport;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.tracing.Tracer;
-import org.elasticsearch.transport.CloseableConnection;
-import org.elasticsearch.transport.ConnectTransportException;
-import org.elasticsearch.transport.ConnectionProfile;
-import org.elasticsearch.transport.NodeNotConnectedException;
-import org.elasticsearch.transport.RemoteTransportException;
-import org.elasticsearch.transport.RequestHandlerRegistry;
-import org.elasticsearch.transport.TransportChannel;
-import org.elasticsearch.transport.TransportException;
-import org.elasticsearch.transport.TransportInterceptor;
-import org.elasticsearch.transport.TransportRequest;
-import org.elasticsearch.transport.TransportRequestOptions;
-import org.elasticsearch.transport.TransportResponse;
-import org.elasticsearch.transport.TransportService;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -48,14 +37,13 @@ import static org.elasticsearch.test.ESTestCase.copyWriteable;
 
 public abstract class DisruptableMockTransport extends MockTransport {
     private final DiscoveryNode localNode;
-    private final Logger logger;
+    private final Logger logger = LogManager.getLogger(DisruptableMockTransport.class);
     private final DeterministicTaskQueue deterministicTaskQueue;
     private final List<Runnable> blackholedRequests = new ArrayList<>();
     private final Set<String> blockedActions = new HashSet<>();
 
-    public DisruptableMockTransport(DiscoveryNode localNode, Logger logger, DeterministicTaskQueue deterministicTaskQueue) {
+    public DisruptableMockTransport(DiscoveryNode localNode, DeterministicTaskQueue deterministicTaskQueue) {
         this.localNode = localNode;
-        this.logger = logger;
         this.deterministicTaskQueue = deterministicTaskQueue;
     }
 
@@ -111,7 +99,12 @@ public abstract class DisruptableMockTransport extends MockTransport {
                     public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options)
                         throws TransportException {
                         if (blockedActions.contains(action)) {
-                            execute(new Runnable() {
+                            execute(new RebootSensitiveRunnable() {
+                                @Override
+                                public void ifRebooted() {
+                                    cleanupResponseHandler(requestId);
+                                }
+
                                 @Override
                                 public void run() {
                                     handleError(
@@ -127,7 +120,7 @@ public abstract class DisruptableMockTransport extends MockTransport {
 
                                 @Override
                                 public String toString() {
-                                    return "error response delivery for action [" + action + "] on node [" + node + "]";
+                                    return "error response delivery for blocked action [" + action + "] on node [" + node + "]";
                                 }
                             });
                         } else {
@@ -172,28 +165,20 @@ public abstract class DisruptableMockTransport extends MockTransport {
             @Override
             public void ifRebooted() {
                 request.decRef();
-                deterministicTaskQueue.scheduleNow(new Runnable() {
+                execute(new RebootSensitiveRunnable() {
                     @Override
-                    public void run() {
-                        execute(new Runnable() {
-                            @Override
-                            public void run() {
-                                handleRemoteError(
-                                    requestId,
-                                    new NodeNotConnectedException(destinationTransport.getLocalNode(), "node rebooted")
-                                );
-                            }
+                    public void ifRebooted() {
+                        cleanupResponseHandler(requestId);
+                    }
 
-                            @Override
-                            public String toString() {
-                                return "error response (reboot) to " + internalToString();
-                            }
-                        });
+                    @Override
+                    public void run() {
+                        handleRemoteError(requestId, new NodeNotConnectedException(destinationTransport.getLocalNode(), "node rebooted"));
                     }
 
                     @Override
                     public String toString() {
-                        return "scheduling of error response (reboot) to " + internalToString();
+                        return "error response (reboot) to " + internalToString();
                     }
                 });
             }
@@ -210,7 +195,12 @@ public abstract class DisruptableMockTransport extends MockTransport {
     }
 
     protected Runnable getDisconnectException(long requestId, String action, DiscoveryNode destination) {
-        return new Runnable() {
+        return new RebootSensitiveRunnable() {
+            @Override
+            public void ifRebooted() {
+                cleanupResponseHandler(requestId);
+            }
+
             @Override
             public void run() {
                 handleError(requestId, new ConnectTransportException(destination, "disconnected"));
@@ -272,13 +262,20 @@ public abstract class DisruptableMockTransport extends MockTransport {
 
             @Override
             public void sendResponse(final TransportResponse response) {
-                execute(new Runnable() {
+                execute(new RebootSensitiveRunnable() {
+                    @Override
+                    public void ifRebooted() {
+                        response.decRef();
+                        cleanupResponseHandler(requestId);
+                    }
+
                     @Override
                     public void run() {
                         final ConnectionStatus connectionStatus = destinationTransport.getConnectionStatus(getLocalNode());
                         switch (connectionStatus) {
                             case CONNECTED, BLACK_HOLE_REQUESTS_ONLY -> handleResponse(requestId, response);
                             case BLACK_HOLE, DISCONNECTED -> {
+                                response.decRef();
                                 logger.trace("delaying response to {}: channel is {}", requestDescription, connectionStatus);
                                 onBlackholedDuringSend(requestId, action, destinationTransport);
                             }
@@ -295,8 +292,12 @@ public abstract class DisruptableMockTransport extends MockTransport {
 
             @Override
             public void sendResponse(Exception exception) {
+                execute(new RebootSensitiveRunnable() {
+                    @Override
+                    public void ifRebooted() {
+                        cleanupResponseHandler(requestId);
+                    }
 
-                execute(new Runnable() {
                     @Override
                     public void run() {
                         final ConnectionStatus connectionStatus = destinationTransport.getConnectionStatus(getLocalNode());
@@ -333,6 +334,18 @@ public abstract class DisruptableMockTransport extends MockTransport {
             } catch (Exception ee) {
                 logger.warn("failed to send failure", e);
             }
+        } finally {
+            copiedRequest.decRef();
+        }
+    }
+
+    private void cleanupResponseHandler(long requestId) { // runs a handler's cleanup action without delivering any response to it
+        TransportResponseHandler<?> handler = getTransportResponseHandler(requestId);
+        while (handler instanceof TransportService.ContextRestoreResponseHandler<?> contextRestoreHandler) {
+            handler = contextRestoreHandler.unwrap(); // peel off wrappers until the underlying handler is reached
+        }
+        if (handler instanceof CleanableResponseHandler<?> cleanableResponseHandler) {
+            cleanableResponseHandler.runCleanup(); // any other handler type (or null) is left untouched
         }
     }
 

+ 258 - 35
test/framework/src/test/java/org/elasticsearch/test/disruption/DisruptableMockTransportTests.java → test/framework/src/test/java/org/elasticsearch/transport/DisruptableMockTransportTests.java

@@ -6,29 +6,27 @@
  * Side Public License, v 1.
  */
 
-package org.elasticsearch.test.disruption;
+package org.elasticsearch.transport;
 
+import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.cluster.coordination.CleanableResponseHandler;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue;
+import org.elasticsearch.core.AbstractRefCounted;
+import org.elasticsearch.core.RefCounted;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.test.disruption.DisruptableMockTransport.ConnectionStatus;
 import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.transport.AbstractSimpleTransportTestCase;
-import org.elasticsearch.transport.ConnectTransportException;
-import org.elasticsearch.transport.TransportChannel;
-import org.elasticsearch.transport.TransportException;
-import org.elasticsearch.transport.TransportRequest;
-import org.elasticsearch.transport.TransportRequestHandler;
-import org.elasticsearch.transport.TransportResponse;
-import org.elasticsearch.transport.TransportResponseHandler;
-import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.transport.DisruptableMockTransport.ConnectionStatus;
+import org.junit.After;
 import org.junit.Before;
 
 import java.io.IOException;
@@ -41,6 +39,7 @@ import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
+import java.util.stream.Stream;
 
 import static org.elasticsearch.transport.TransportService.NOOP_TRANSPORT_INTERCEPTOR;
 import static org.hamcrest.Matchers.containsString;
@@ -49,6 +48,8 @@ import static org.hamcrest.Matchers.instanceOf;
 
 public class DisruptableMockTransportTests extends ESTestCase {
 
+    private static final String TEST_ACTION = "internal:dummy";
+
     private DiscoveryNode node1;
     private DiscoveryNode node2;
 
@@ -58,11 +59,15 @@ public class DisruptableMockTransportTests extends ESTestCase {
     private DeterministicTaskQueue deterministicTaskQueue;
 
     private Runnable deliverBlackholedRequests;
+    private Runnable blockTestAction;
 
     private Set<Tuple<DiscoveryNode, DiscoveryNode>> disconnectedLinks;
     private Set<Tuple<DiscoveryNode, DiscoveryNode>> blackholedLinks;
     private Set<Tuple<DiscoveryNode, DiscoveryNode>> blackholedRequestLinks;
 
+    private long activeRequestCount;
+    private final Set<DiscoveryNode> rebootedNodes = new HashSet<>();
+
     private ConnectionStatus getConnectionStatus(DiscoveryNode sender, DiscoveryNode destination) {
         Tuple<DiscoveryNode, DiscoveryNode> link = Tuple.tuple(sender, destination);
         if (disconnectedLinks.contains(link)) {
@@ -93,7 +98,7 @@ public class DisruptableMockTransportTests extends ESTestCase {
 
         deterministicTaskQueue = new DeterministicTaskQueue();
 
-        final DisruptableMockTransport transport1 = new DisruptableMockTransport(node1, logger, deterministicTaskQueue) {
+        final DisruptableMockTransport transport1 = new DisruptableMockTransport(node1, deterministicTaskQueue) {
             @Override
             protected ConnectionStatus getConnectionStatus(DiscoveryNode destination) {
                 return DisruptableMockTransportTests.this.getConnectionStatus(getLocalNode(), destination);
@@ -106,11 +111,11 @@ public class DisruptableMockTransportTests extends ESTestCase {
 
             @Override
             protected void execute(Runnable runnable) {
-                deterministicTaskQueue.scheduleNow(runnable);
+                deterministicTaskQueue.scheduleNow(unlessRebooted(node1, runnable));
             }
         };
 
-        final DisruptableMockTransport transport2 = new DisruptableMockTransport(node2, logger, deterministicTaskQueue) {
+        final DisruptableMockTransport transport2 = new DisruptableMockTransport(node2, deterministicTaskQueue) {
             @Override
             protected ConnectionStatus getConnectionStatus(DiscoveryNode destination) {
                 return DisruptableMockTransportTests.this.getConnectionStatus(getLocalNode(), destination);
@@ -123,7 +128,7 @@ public class DisruptableMockTransportTests extends ESTestCase {
 
             @Override
             protected void execute(Runnable runnable) {
-                deterministicTaskQueue.scheduleNow(runnable);
+                deterministicTaskQueue.scheduleNow(unlessRebooted(node2, runnable));
             }
         };
 
@@ -159,42 +164,96 @@ public class DisruptableMockTransportTests extends ESTestCase {
         assertTrue(fut2.isDone());
 
         deliverBlackholedRequests = () -> transports.forEach(DisruptableMockTransport::deliverBlackholedRequests);
+
+        blockTestAction = new Runnable() {
+            @Override
+            public void run() {
+                transports.forEach(t -> t.addActionBlock(TEST_ACTION));
+            }
+
+            @Override
+            public String toString() {
+                return "add block for " + TEST_ACTION;
+            }
+        };
+
+        activeRequestCount = 0;
+        rebootedNodes.clear();
+    }
+
+    @After
+    public void assertAllRequestsReleased() { // post-test check that the active-request counter is back at zero
+        assertEquals(0, activeRequestCount); // NOTE(review): counter updates are not visible here — presumably tracks live TestRequest instances
+    }
+
+    private Runnable reboot(DiscoveryNode discoveryNode) { // builds a task that marks discoveryNode as rebooted when it runs
+        return new Runnable() {
+            @Override
+            public void run() {
+                rebootedNodes.add(discoveryNode); // unlessRebooted(...) consults this set each time a wrapped task runs
+            }
+
+            @Override
+            public String toString() { // human-readable description of the task
+                return "reboot " + discoveryNode;
+            }
+        };
+    }
+
+    private Runnable unlessRebooted(DiscoveryNode discoveryNode, Runnable runnable) { // wraps runnable with a reboot check
+        return new Runnable() {
+            @Override
+            public void run() {
+                if (rebootedNodes.contains(discoveryNode)) { // evaluated when the wrapper runs, not when it is created
+                    if (runnable instanceof DisruptableMockTransport.RebootSensitiveRunnable rebootSensitiveRunnable) {
+                        rebootSensitiveRunnable.ifRebooted(); // reboot path is taken instead of run()
+                    } // a runnable of any other type is dropped without being run
+                } else {
+                    runnable.run();
+                }
+            }
+
+            @Override
+            public String toString() { // includes the node id and the wrapped task's own description
+                return "unlessRebooted[" + discoveryNode.getId() + "/" + runnable + "]";
+            }
+        };
+    }
 
-    private TransportRequestHandler<TransportRequest.Empty> requestHandlerShouldNotBeCalled() {
+    private TransportRequestHandler<TestRequest> requestHandlerShouldNotBeCalled() {
         return (request, channel, task) -> { throw new AssertionError("should not be called"); };
     }
 
-    private TransportRequestHandler<TransportRequest.Empty> requestHandlerRepliesNormally() {
+    private TransportRequestHandler<TestRequest> requestHandlerRepliesNormally() {
         return (request, channel, task) -> {
             logger.debug("got a dummy request, replying normally...");
-            channel.sendResponse(TransportResponse.Empty.INSTANCE);
+            channel.sendResponse(new TestResponse());
         };
     }
 
-    private TransportRequestHandler<TransportRequest.Empty> requestHandlerRepliesExceptionally(Exception e) {
+    private TransportRequestHandler<TestRequest> requestHandlerRepliesExceptionally(Exception e) {
         return (request, channel, task) -> {
             logger.debug("got a dummy request, replying exceptionally...");
             channel.sendResponse(e);
         };
     }
 
-    private TransportRequestHandler<TransportRequest.Empty> requestHandlerCaptures(Consumer<TransportChannel> channelConsumer) {
+    private TransportRequestHandler<TestRequest> requestHandlerCaptures(Consumer<TransportChannel> channelConsumer) {
         return (request, channel, task) -> {
             logger.debug("got a dummy request...");
             channelConsumer.accept(channel);
         };
     }
 
-    private TransportResponseHandler<TransportResponse> responseHandlerShouldNotBeCalled() {
+    private <T extends TransportResponse> TransportResponseHandler<T> responseHandlerShouldNotBeCalled() {
         return new TransportResponseHandler<>() {
             @Override
-            public TransportResponse read(StreamInput in) {
+            public T read(StreamInput in) {
                 throw new AssertionError("should not be called");
             }
 
             @Override
-            public void handleResponse(TransportResponse response) {
+            public void handleResponse(T response) {
                 throw new AssertionError("should not be called");
             }
 
@@ -205,10 +264,15 @@ public class DisruptableMockTransportTests extends ESTestCase {
         };
     }
 
-    private TransportResponseHandler<TransportResponse.Empty> responseHandlerShouldBeCalledNormally(Runnable onCalled) {
-        return new TransportResponseHandler.Empty() {
+    private TransportResponseHandler<TestResponse> responseHandlerShouldBeCalledNormally(Runnable onCalled) {
+        return new TransportResponseHandler<>() {
+            @Override
+            public TestResponse read(StreamInput in) throws IOException {
+                return new TestResponse(in);
+            }
+
             @Override
-            public void handleResponse(TransportResponse.Empty response) {
+            public void handleResponse(TestResponse response) {
                 onCalled.run();
             }
 
@@ -219,15 +283,17 @@ public class DisruptableMockTransportTests extends ESTestCase {
         };
     }
 
-    private TransportResponseHandler<TransportResponse> responseHandlerShouldBeCalledExceptionally(Consumer<TransportException> onCalled) {
+    private <T extends TransportResponse> TransportResponseHandler<T> responseHandlerShouldBeCalledExceptionally(
+        Consumer<TransportException> onCalled
+    ) {
         return new TransportResponseHandler<>() {
             @Override
-            public TransportResponse read(StreamInput in) {
+            public T read(StreamInput in) {
                 throw new AssertionError("should not be called");
             }
 
             @Override
-            public void handleResponse(TransportResponse response) {
+            public void handleResponse(T response) {
                 throw new AssertionError("should not be called");
             }
 
@@ -238,8 +304,8 @@ public class DisruptableMockTransportTests extends ESTestCase {
         };
     }
 
-    private void registerRequestHandler(TransportService transportService, TransportRequestHandler<TransportRequest.Empty> handler) {
-        transportService.registerRequestHandler("internal:dummy", ThreadPool.Names.GENERIC, TransportRequest.Empty::new, handler);
+    // Registers {@code handler} for TEST_ACTION on the given transport service, with requests constructed via TestRequest::new.
+    private void registerRequestHandler(TransportService transportService, TransportRequestHandler<TestRequest> handler) {
+        transportService.registerRequestHandler(TEST_ACTION, ThreadPool.Names.GENERIC, TestRequest::new, handler);
     }
 
     private void send(
@@ -247,7 +313,12 @@ public class DisruptableMockTransportTests extends ESTestCase {
         DiscoveryNode destinationNode,
         TransportResponseHandler<? extends TransportResponse> responseHandler
     ) {
-        transportService.sendRequest(destinationNode, "internal:dummy", TransportRequest.Empty.INSTANCE, responseHandler);
+        final var request = new TestRequest();
+        try {
+            transportService.sendRequest(destinationNode, TEST_ACTION, request, responseHandler);
+        } finally {
+            request.decRef();
+        }
     }
 
     public void testSuccessfulResponse() {
@@ -259,6 +330,18 @@ public class DisruptableMockTransportTests extends ESTestCase {
         assertTrue(responseHandlerCalled.get());
     }
 
+    /**
+     * Verifies that a request sent from service1 to node2 after {@code blockTestAction} has run completes exceptionally: the
+     * response handler must receive a {@link TransportException} whose cause reports that TEST_ACTION is blocked.
+     */
+    public void testBlockedAction() {
+        registerRequestHandler(service1, requestHandlerShouldNotBeCalled());
+        registerRequestHandler(service2, requestHandlerRepliesNormally());
+        blockTestAction.run();
+        AtomicReference<TransportException> responseHandlerException = new AtomicReference<>();
+        send(service1, node2, responseHandlerShouldBeCalledExceptionally(responseHandlerException::set));
+        deterministicTaskQueue.runAllRunnableTasks();
+        assertNotNull(responseHandlerException.get());
+        assertNotNull(responseHandlerException.get().getCause());
+        assertThat(responseHandlerException.get().getCause().getMessage(), containsString("action [" + TEST_ACTION + "] is blocked"));
+    }
+
     public void testExceptionalResponse() {
         registerRequestHandler(service1, requestHandlerShouldNotBeCalled());
         Exception e = new Exception("dummy exception");
@@ -310,7 +393,7 @@ public class DisruptableMockTransportTests extends ESTestCase {
         assertNull(responseHandlerException.get());
 
         disconnectedLinks.add(Tuple.tuple(node2, node1));
-        responseHandlerChannel.get().sendResponse(TransportResponse.Empty.INSTANCE);
+        responseHandlerChannel.get().sendResponse(new TestResponse());
         deterministicTaskQueue.runAllTasks();
         deliverBlackholedRequests.run();
         deterministicTaskQueue.runAllTasks();
@@ -348,7 +431,7 @@ public class DisruptableMockTransportTests extends ESTestCase {
         assertNotNull(responseHandlerChannel.get());
 
         blackholedLinks.add(Tuple.tuple(node2, node1));
-        responseHandlerChannel.get().sendResponse(TransportResponse.Empty.INSTANCE);
+        responseHandlerChannel.get().sendResponse(new TestResponse());
         deterministicTaskQueue.runAllRunnableTasks();
     }
 
@@ -380,7 +463,7 @@ public class DisruptableMockTransportTests extends ESTestCase {
 
         blackholedRequestLinks.add(Tuple.tuple(node1, node2));
         blackholedRequestLinks.add(Tuple.tuple(node2, node1));
-        responseHandlerChannel.get().sendResponse(TransportResponse.Empty.INSTANCE);
+        responseHandlerChannel.get().sendResponse(new TestResponse());
 
         deterministicTaskQueue.runAllRunnableTasks();
         assertTrue(responseHandlerCalled.get());
@@ -406,6 +489,73 @@ public class DisruptableMockTransportTests extends ESTestCase {
         assertTrue(responseHandlerCalled.get());
     }
 
+    /**
+     * Sends a single request from node1 to node2 while a random subset of disruptions (a reboot of either node, blocking the test
+     * action, or a disruption of either direction of the link between the two nodes) is scheduled alongside it. Once all tasks and
+     * any blackholed requests have been processed, the response handler's cleanup action must have run, whether or not its listener
+     * was ever invoked.
+     */
+    public void testResponseWithReboots() {
+        registerRequestHandler(service1, requestHandlerShouldNotBeCalled());
+        registerRequestHandler(
+            service2,
+            randomFrom(requestHandlerRepliesNormally(), requestHandlerRepliesExceptionally(new ElasticsearchException("simulated")))
+        );
+
+        final var linkDisruptions = List.of(
+            Tuple.tuple("blackhole", blackholedLinks),
+            Tuple.tuple("blackhole-request", blackholedRequestLinks),
+            Tuple.tuple("disconnected", disconnectedLinks)
+        );
+
+        for (Runnable runnable : Stream.concat(
+            Stream.of(reboot(node1), reboot(node2), blockTestAction),
+            // Cover both directions between the two nodes: node1 -> node2 carries the request and node2 -> node1 the response.
+            Stream.of(Tuple.tuple(node1, node2), Tuple.tuple(node2, node1)).map(link -> {
+                final var disruption = randomFrom(linkDisruptions);
+                return new Runnable() {
+                    @Override
+                    public void run() {
+                        // Apply at most one kind of disruption to each link.
+                        if (linkDisruptions.stream().noneMatch(otherDisruption -> otherDisruption.v2().contains(link))) {
+                            disruption.v2().add(link);
+                        }
+                    }
+
+                    @Override
+                    public String toString() {
+                        return disruption.v1() + ": " + link.v1().getId() + " to " + link.v2().getId();
+                    }
+                };
+            })
+        ).toList()) {
+            // Each candidate disruption is independently included or left out.
+            if (randomBoolean()) {
+                deterministicTaskQueue.scheduleNow(runnable);
+            }
+        }
+
+        // Both flags use assertFalse(getAndSet(true)) so that a second invocation of either callback fails the test.
+        AtomicBoolean responseHandlerCalled = new AtomicBoolean();
+        AtomicBoolean responseHandlerReleased = new AtomicBoolean();
+        deterministicTaskQueue.scheduleNow(new Runnable() {
+            @Override
+            public void run() {
+                DisruptableMockTransportTests.this.send(
+                    service1,
+                    node2,
+                    new CleanableResponseHandler<>(
+                        ActionListener.wrap(() -> assertFalse(responseHandlerCalled.getAndSet(true))),
+                        TestResponse::new,
+                        ThreadPool.Names.SAME,
+                        () -> assertFalse(responseHandlerReleased.getAndSet(true))
+                    )
+                );
+            }
+
+            @Override
+            public String toString() {
+                return "send test message";
+            }
+        });
+
+        deterministicTaskQueue.runAllRunnableTasks();
+        deliverBlackholedRequests.run();
+        deterministicTaskQueue.runAllRunnableTasks();
+
+        // Only the release is required: the listener itself may legitimately never be called.
+        assertTrue(responseHandlerReleased.get());
+    }
+
     public void testBrokenLinkFailsToConnect() {
         service1.disconnectFromNode(node2);
 
@@ -440,4 +590,77 @@ public class DisruptableMockTransportTests extends ESTestCase {
             endsWith("does not exist")
         );
     }
+
+    /**
+     * Request type used by these tests. Each instance increments the enclosing test's {@code activeRequestCount} on construction
+     * and hands a decrementing callback to {@link AbstractRefCounted#of}; presumably that callback runs when the last reference is
+     * released, so the counter tracks instances that have not yet been fully released.
+     */
+    private class TestRequest extends TransportRequest {
+        // All ref-counting calls on this request are delegated to this instance.
+        private final RefCounted refCounted;
+
+        TestRequest() {
+            activeRequestCount++;
+            refCounted = AbstractRefCounted.of(() -> activeRequestCount--);
+        }
+
+        // Stream-reading constructor; performs the same accounting as the no-arg constructor.
+        TestRequest(StreamInput in) throws IOException {
+            super(in);
+            activeRequestCount++;
+            refCounted = AbstractRefCounted.of(() -> activeRequestCount--);
+        }
+
+        @Override
+        public void incRef() {
+            refCounted.incRef();
+        }
+
+        @Override
+        public boolean tryIncRef() {
+            return refCounted.tryIncRef();
+        }
+
+        @Override
+        public boolean decRef() {
+            return refCounted.decRef();
+        }
+
+        @Override
+        public boolean hasReferences() {
+            return refCounted.hasReferences();
+        }
+    }
+
+    /**
+     * Response type used by these tests. It shares the {@code activeRequestCount} counter with {@code TestRequest}: each instance
+     * increments it on construction and hands a decrementing callback to {@link AbstractRefCounted#of}; presumably that callback
+     * runs when the last reference is released.
+     */
+    private class TestResponse extends TransportResponse {
+        // All ref-counting calls on this response are delegated to this instance.
+        private final RefCounted refCounted;
+
+        TestResponse() {
+            activeRequestCount++;
+            refCounted = AbstractRefCounted.of(() -> activeRequestCount--);
+        }
+
+        // Stream-reading constructor; performs the same accounting as the no-arg constructor.
+        TestResponse(StreamInput in) throws IOException {
+            super(in);
+            activeRequestCount++;
+            refCounted = AbstractRefCounted.of(() -> activeRequestCount--);
+        }
+
+        // Intentionally empty: nothing is written to the stream for this response.
+        @Override
+        public void writeTo(StreamOutput out) {}
+
+        @Override
+        public void incRef() {
+            refCounted.incRef();
+        }
+
+        @Override
+        public boolean tryIncRef() {
+            return refCounted.tryIncRef();
+        }
+
+        @Override
+        public boolean decRef() {
+            return refCounted.decRef();
+        }
+
+        @Override
+        public boolean hasReferences() {
+            return refCounted.hasReferences();
+        }
+    }
 }