1
0
Эх сурвалжийг харах

Recycle pages used by outgoing publications (#77317)

Today `PublicationTransportHandler.PublicationContext` allocates a bunch
of memory for serialized cluster states and diffs, but it uses a plain
`BytesStreamOutput`, which means that the backing pages are allocated by
`BigArrays#NON_RECYCLING_INSTANCE`. With this commit we pass in a
proper `BigArrays` so that the pages being used can be recycled.
David Turner 4 жил өмнө
parent
commit
8b50fcd3d6

+ 48 - 25
server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java

@@ -34,6 +34,7 @@ import org.elasticsearch.cluster.routing.allocation.AllocationService;
 import org.elasticsearch.cluster.service.ClusterApplier;
 import org.elasticsearch.cluster.service.ClusterApplier.ClusterApplyListener;
 import org.elasticsearch.cluster.service.MasterService;
+import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.common.Priority;
 import org.elasticsearch.common.Strings;
@@ -147,11 +148,24 @@ public class Coordinator extends AbstractLifecycleComponent implements Discovery
      * @param nodeName The name of the node, used to name the {@link java.util.concurrent.ExecutorService} of the {@link SeedHostsResolver}.
      * @param onJoinValidators A collection of join validators to restrict which nodes may join the cluster.
      */
-    public Coordinator(String nodeName, Settings settings, ClusterSettings clusterSettings, TransportService transportService,
-                       NamedWriteableRegistry namedWriteableRegistry, AllocationService allocationService, MasterService masterService,
-                       Supplier<CoordinationState.PersistedState> persistedStateSupplier, SeedHostsProvider seedHostsProvider,
-                       ClusterApplier clusterApplier, Collection<BiConsumer<DiscoveryNode, ClusterState>> onJoinValidators, Random random,
-                       RerouteService rerouteService, ElectionStrategy electionStrategy, NodeHealthService nodeHealthService) {
+    public Coordinator(
+        String nodeName,
+        Settings settings,
+        ClusterSettings clusterSettings,
+        BigArrays bigArrays,
+        TransportService transportService,
+        NamedWriteableRegistry namedWriteableRegistry,
+        AllocationService allocationService,
+        MasterService masterService,
+        Supplier<CoordinationState.PersistedState> persistedStateSupplier,
+        SeedHostsProvider seedHostsProvider,
+        ClusterApplier clusterApplier,
+        Collection<BiConsumer<DiscoveryNode, ClusterState>> onJoinValidators,
+        Random random,
+        RerouteService rerouteService,
+        ElectionStrategy electionStrategy,
+        NodeHealthService nodeHealthService
+    ) {
         this.settings = settings;
         this.transportService = transportService;
         this.masterService = masterService;
@@ -176,8 +190,13 @@ public class Coordinator extends AbstractLifecycleComponent implements Discovery
         configuredHostsResolver = new SeedHostsResolver(nodeName, settings, transportService, seedHostsProvider);
         this.peerFinder = new CoordinatorPeerFinder(settings, transportService,
             new HandshakingTransportAddressConnector(settings, transportService), configuredHostsResolver);
-        this.publicationHandler = new PublicationTransportHandler(transportService, namedWriteableRegistry,
-            this::handlePublishRequest, this::handleApplyCommit);
+        this.publicationHandler = new PublicationTransportHandler(
+            bigArrays,
+            transportService,
+            namedWriteableRegistry,
+            this::handlePublishRequest,
+            this::handleApplyCommit
+        );
         this.leaderChecker = new LeaderChecker(settings, transportService, this::onLeaderFailure, nodeHealthService);
         this.followersChecker = new FollowersChecker(settings, transportService, this::onFollowerCheckRequest, this::removeNode,
             nodeHealthService);
@@ -1071,24 +1090,28 @@ public class Coordinator extends AbstractLifecycleComponent implements Discovery
                 final long publicationContextConstructionStartMillis = transportService.getThreadPool().rawRelativeTimeInMillis();
                 final PublicationTransportHandler.PublicationContext publicationContext =
                     publicationHandler.newPublicationContext(clusterStatePublicationEvent);
-                clusterStatePublicationEvent.setPublicationContextConstructionElapsedMillis(
-                    transportService.getThreadPool().rawRelativeTimeInMillis() - publicationContextConstructionStartMillis);
-
-                final PublishRequest publishRequest = coordinationState.get().handleClientValue(clusterState);
-                final CoordinatorPublication publication = new CoordinatorPublication(
-                    clusterStatePublicationEvent,
-                    publishRequest,
-                    publicationContext,
-                    new ListenableFuture<>(),
-                    ackListener,
-                    publishListener);
-                currentPublication = Optional.of(publication);
-
-                final DiscoveryNodes publishNodes = publishRequest.getAcceptedState().nodes();
-                leaderChecker.setCurrentNodes(publishNodes);
-                followersChecker.setCurrentNodes(publishNodes);
-                lagDetector.setTrackedNodes(publishNodes);
-                publication.start(followersChecker.getFaultyNodes());
+                try {
+                    clusterStatePublicationEvent.setPublicationContextConstructionElapsedMillis(
+                        transportService.getThreadPool().rawRelativeTimeInMillis() - publicationContextConstructionStartMillis);
+
+                    final PublishRequest publishRequest = coordinationState.get().handleClientValue(clusterState);
+                    final CoordinatorPublication publication = new CoordinatorPublication(
+                        clusterStatePublicationEvent,
+                        publishRequest,
+                        publicationContext,
+                        new ListenableFuture<>(),
+                        ackListener,
+                        publishListener);
+                    currentPublication = Optional.of(publication);
+
+                    final DiscoveryNodes publishNodes = publishRequest.getAcceptedState().nodes();
+                    leaderChecker.setCurrentNodes(publishNodes);
+                    followersChecker.setCurrentNodes(publishNodes);
+                    lagDetector.setTrackedNodes(publishNodes);
+                    publication.start(followersChecker.getFaultyNodes());
+                } finally {
+                    publicationContext.decRef();
+                }
             }
         } catch (Exception e) {
             logger.debug(() -> new ParameterizedMessage("[{}] publishing failed", clusterStatePublicationEvent.getSummary()), e);

+ 170 - 144
server/src/main/java/org/elasticsearch/cluster/coordination/PublicationTransportHandler.java

@@ -13,39 +13,45 @@ import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.cluster.ClusterStatePublicationEvent;
+import org.elasticsearch.action.ActionListenerResponseHandler;
+import org.elasticsearch.action.support.ChannelActionListener;
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.ClusterStatePublicationEvent;
 import org.elasticsearch.cluster.Diff;
 import org.elasticsearch.cluster.IncompatibleClusterStateVersionException;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
-import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.bytes.ReleasableBytesReference;
 import org.elasticsearch.common.compress.Compressor;
 import org.elasticsearch.common.compress.CompressorFactory;
+import org.elasticsearch.common.io.Streams;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.InputStreamStreamInput;
 import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
+import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.LazyInitializable;
+import org.elasticsearch.core.AbstractRefCounted;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.internal.io.IOUtils;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.BytesTransportRequest;
-import org.elasticsearch.transport.TransportChannel;
 import org.elasticsearch.transport.TransportException;
 import org.elasticsearch.transport.TransportRequestOptions;
 import org.elasticsearch.transport.TransportResponse;
-import org.elasticsearch.transport.TransportResponseHandler;
 import org.elasticsearch.transport.TransportService;
 
 import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
-import java.util.function.Consumer;
 import java.util.function.Function;
 
 public class PublicationTransportHandler {
@@ -55,6 +61,7 @@ public class PublicationTransportHandler {
     public static final String PUBLISH_STATE_ACTION_NAME = "internal:cluster/coordination/publish_state";
     public static final String COMMIT_STATE_ACTION_NAME = "internal:cluster/coordination/commit_state";
 
+    private final BigArrays bigArrays;
     private final TransportService transportService;
     private final NamedWriteableRegistry namedWriteableRegistry;
     private final Function<PublishRequest, PublishWithJoinResponse> handlePublishRequest;
@@ -76,9 +83,14 @@ public class PublicationTransportHandler {
     private static final TransportRequestOptions STATE_REQUEST_OPTIONS =
             TransportRequestOptions.of(null, TransportRequestOptions.Type.STATE);
 
-    public PublicationTransportHandler(TransportService transportService, NamedWriteableRegistry namedWriteableRegistry,
-                                       Function<PublishRequest, PublishWithJoinResponse> handlePublishRequest,
-                                       BiConsumer<ApplyCommitRequest, ActionListener<Void>> handleApplyCommit) {
+    public PublicationTransportHandler(
+        BigArrays bigArrays,
+        TransportService transportService,
+        NamedWriteableRegistry namedWriteableRegistry,
+        Function<PublishRequest, PublishWithJoinResponse> handlePublishRequest,
+        BiConsumer<ApplyCommitRequest, ActionListener<Void>> handleApplyCommit
+    ) {
+        this.bigArrays = bigArrays;
         this.transportService = transportService;
         this.namedWriteableRegistry = namedWriteableRegistry;
         this.handlePublishRequest = handlePublishRequest;
@@ -88,31 +100,9 @@ public class PublicationTransportHandler {
 
         transportService.registerRequestHandler(COMMIT_STATE_ACTION_NAME, ThreadPool.Names.GENERIC, false, false,
             ApplyCommitRequest::new,
-            (request, channel, task) -> handleApplyCommit.accept(request, transportCommitCallback(channel)));
-    }
-
-    private ActionListener<Void> transportCommitCallback(TransportChannel channel) {
-        return new ActionListener<Void>() {
-
-            @Override
-            public void onResponse(Void aVoid) {
-                try {
-                    channel.sendResponse(TransportResponse.Empty.INSTANCE);
-                } catch (IOException e) {
-                    logger.debug("failed to send response on commit", e);
-                }
-            }
-
-            @Override
-            public void onFailure(Exception e) {
-                try {
-                    channel.sendResponse(e);
-                } catch (IOException ie) {
-                    e.addSuppressed(ie);
-                    logger.debug("failed to send response on commit", e);
-                }
-            }
-        };
+            (request, channel, task) -> handleApplyCommit.accept(
+                request,
+                new ChannelActionListener<>(channel, COMMIT_STATE_ACTION_NAME, request).map(r -> TransportResponse.Empty.INSTANCE)));
     }
 
     public PublishClusterStateStats stats() {
@@ -197,50 +187,96 @@ public class PublicationTransportHandler {
 
     public PublicationContext newPublicationContext(ClusterStatePublicationEvent clusterStatePublicationEvent) {
         final PublicationContext publicationContext = new PublicationContext(clusterStatePublicationEvent);
-
-        // Build the serializations we expect to need now, early in the process, so that an error during serialization fails the publication
-        // straight away. This isn't watertight since we send diffs on a best-effort basis and may fall back to sending a full state (and
-        // therefore serializing it) if the diff-based publication fails.
-        publicationContext.buildDiffAndSerializeStates();
-        return publicationContext;
+        boolean success = false;
+        try {
+            // Build the serializations we expect to need now, early in the process, so that an error during serialization fails the
+            // publication straight away. This isn't watertight since we send diffs on a best-effort basis and may fall back to sending a
+            // full state (and therefore serializing it) if the diff-based publication fails.
+            publicationContext.buildDiffAndSerializeStates();
+            success = true;
+            return publicationContext;
+        } finally {
+            if (success == false) {
+                publicationContext.decRef();
+            }
+        }
     }
 
-    private static BytesReference serializeFullClusterState(ClusterState clusterState, Version nodeVersion) throws IOException {
-        final BytesStreamOutput bStream = new BytesStreamOutput();
-        try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.COMPRESSOR.threadLocalOutputStream(bStream))) {
-            stream.setVersion(nodeVersion);
-            stream.writeBoolean(true);
-            clusterState.writeTo(stream);
+    private ReleasableBytesReference serializeFullClusterState(ClusterState clusterState, DiscoveryNode node) {
+        final Version nodeVersion = node.getVersion();
+        final BytesStreamOutput bytesStream = new ReleasableBytesStreamOutput(bigArrays);
+        boolean success = false;
+        try {
+            try (StreamOutput stream = new OutputStreamStreamOutput(
+                CompressorFactory.COMPRESSOR.threadLocalOutputStream(Streams.flushOnCloseStream(bytesStream)))
+            ) {
+                stream.setVersion(nodeVersion);
+                stream.writeBoolean(true);
+                clusterState.writeTo(stream);
+            } catch (IOException e) {
+                throw new ElasticsearchException("failed to serialize cluster state for publishing to node {}", e, node);
+            }
+            final ReleasableBytesReference result = new ReleasableBytesReference(bytesStream.bytes(), bytesStream::close);
+            logger.trace(
+                "serialized full cluster state version [{}] for node version [{}] with size [{}]",
+                clusterState.version(),
+                nodeVersion,
+                result.length());
+            success = true;
+            return result;
+        } finally {
+            if (success == false) {
+                bytesStream.close();
+            }
         }
-        final BytesReference serializedState = bStream.bytes();
-        logger.trace("serialized full cluster state version [{}] for node version [{}] with size [{}]",
-            clusterState.version(), nodeVersion, serializedState.length());
-        return serializedState;
     }
 
-    private static BytesReference serializeDiffClusterState(Diff<ClusterState> diff, Version nodeVersion) throws IOException {
-        final BytesStreamOutput bStream = new BytesStreamOutput();
-        try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.COMPRESSOR.threadLocalOutputStream(bStream))) {
-            stream.setVersion(nodeVersion);
-            stream.writeBoolean(false);
-            diff.writeTo(stream);
+    private ReleasableBytesReference serializeDiffClusterState(long clusterStateVersion, Diff<ClusterState> diff, DiscoveryNode node) {
+        final Version nodeVersion = node.getVersion();
+        final BytesStreamOutput bytesStream = new ReleasableBytesStreamOutput(bigArrays);
+        boolean success = false;
+        try {
+            try (StreamOutput stream = new OutputStreamStreamOutput(
+                CompressorFactory.COMPRESSOR.threadLocalOutputStream(Streams.flushOnCloseStream(bytesStream)))
+            ) {
+                stream.setVersion(nodeVersion);
+                stream.writeBoolean(false);
+                diff.writeTo(stream);
+            } catch (IOException e) {
+                throw new ElasticsearchException("failed to serialize cluster state diff for publishing to node {}", e, node);
+            }
+            final ReleasableBytesReference result = new ReleasableBytesReference(bytesStream.bytes(), bytesStream::close);
+            logger.trace(
+                "serialized cluster state diff for version [{}] for node version [{}] with size [{}]",
+                clusterStateVersion,
+                nodeVersion,
+                result.length());
+            success = true;
+            return result;
+        } finally {
+            if (success == false) {
+                bytesStream.close();
+            }
         }
-        return bStream.bytes();
     }
 
     /**
      * Publishing a cluster state typically involves sending the same cluster state (or diff) to every node, so the work of diffing,
      * serializing, and compressing the state can be done once and the results shared across publish requests. The
-     * {@code PublicationContext} implements this sharing.
+     * {@code PublicationContext} implements this sharing. It's ref-counted: the initial reference is released by the coordinator when
+     * a state (or diff) has been sent to every node; every transmitted diff also holds a reference in case it needs to retry with a full
+     * state.
      */
-    public class PublicationContext {
+    public class PublicationContext extends AbstractRefCounted {
 
         private final DiscoveryNodes discoveryNodes;
         private final ClusterState newState;
         private final ClusterState previousState;
         private final boolean sendFullVersion;
-        private final Map<Version, BytesReference> serializedStates = new HashMap<>();
-        private final Map<Version, BytesReference> serializedDiffs = new HashMap<>();
+
+        // All the values of these maps have one ref for the context (while it's open) and one for each in-flight message.
+        private final Map<Version, ReleasableBytesReference> serializedStates = new ConcurrentHashMap<>();
+        private final Map<Version, ReleasableBytesReference> serializedDiffs = new HashMap<>();
 
         PublicationContext(ClusterStatePublicationEvent clusterStatePublicationEvent) {
             discoveryNodes = clusterStatePublicationEvent.getNewState().nodes();
@@ -250,33 +286,25 @@ public class PublicationTransportHandler {
         }
 
         void buildDiffAndSerializeStates() {
-            Diff<ClusterState> diff = null;
+            assert refCount() > 0;
+            final LazyInitializable<Diff<ClusterState>, RuntimeException> diffSupplier
+                = new LazyInitializable<>(() -> newState.diff(previousState));
             for (DiscoveryNode node : discoveryNodes) {
-                try {
-                    if (sendFullVersion || previousState.nodes().nodeExists(node) == false) {
-                        if (serializedStates.containsKey(node.getVersion()) == false) {
-                            serializedStates.put(node.getVersion(), serializeFullClusterState(newState, node.getVersion()));
-                        }
-                    } else {
-                        // will send a diff
-                        if (diff == null) {
-                            diff = newState.diff(previousState);
-                        }
-                        if (serializedDiffs.containsKey(node.getVersion()) == false) {
-                            final BytesReference serializedDiff = serializeDiffClusterState(diff, node.getVersion());
-                            serializedDiffs.put(node.getVersion(), serializedDiff);
-                            logger.trace("serialized cluster state diff for version [{}] in for node version [{}] with size [{}]",
-                                newState.version(), node.getVersion(), serializedDiff.length());
-                        }
-                    }
-                } catch (IOException e) {
-                    throw new ElasticsearchException("failed to serialize cluster state for publishing to node {}", e, node);
+                if (sendFullVersion || previousState.nodes().nodeExists(node) == false) {
+                    serializedStates.computeIfAbsent(
+                        node.getVersion(),
+                        v -> serializeFullClusterState(newState, node));
+                } else {
+                    serializedDiffs.computeIfAbsent(
+                        node.getVersion(),
+                        v -> serializeDiffClusterState(newState.version(), diffSupplier.getOrCompute(), node));
                 }
             }
         }
 
         public void sendPublishRequest(DiscoveryNode destination, PublishRequest publishRequest,
                                        ActionListener<PublishWithJoinResponse> listener) {
+            assert refCount() > 0;
             assert publishRequest.getAcceptedState() == newState : "state got switched on us";
             assert transportService.getThreadPool().getThreadContext().isSystemContext();
             final ActionListener<PublishWithJoinResponse> responseActionListener;
@@ -314,37 +342,22 @@ public class PublicationTransportHandler {
         public void sendApplyCommit(DiscoveryNode destination, ApplyCommitRequest applyCommitRequest,
                                     ActionListener<TransportResponse.Empty> listener) {
             assert transportService.getThreadPool().getThreadContext().isSystemContext();
-            transportService.sendRequest(destination, COMMIT_STATE_ACTION_NAME, applyCommitRequest, STATE_REQUEST_OPTIONS,
-                new TransportResponseHandler<TransportResponse.Empty>() {
-
-                    @Override
-                    public TransportResponse.Empty read(StreamInput in) {
-                        return TransportResponse.Empty.INSTANCE;
-                    }
-
-                    @Override
-                    public void handleResponse(TransportResponse.Empty response) {
-                        listener.onResponse(response);
-                    }
-
-                    @Override
-                    public void handleException(TransportException exp) {
-                        listener.onFailure(exp);
-                    }
-
-                    @Override
-                    public String executor() {
-                        return ThreadPool.Names.GENERIC;
-                    }
-                });
+            transportService.sendRequest(
+                destination,
+                COMMIT_STATE_ACTION_NAME,
+                applyCommitRequest,
+                STATE_REQUEST_OPTIONS,
+                new ActionListenerResponseHandler<>(listener, in -> TransportResponse.Empty.INSTANCE, ThreadPool.Names.GENERIC));
         }
 
         private void sendFullClusterState(DiscoveryNode destination, ActionListener<PublishWithJoinResponse> listener) {
-            BytesReference bytes = serializedStates.get(destination.getVersion());
+            assert refCount() > 0;
+            ReleasableBytesReference bytes = serializedStates.get(destination.getVersion());
             if (bytes == null) {
                 try {
-                    bytes = serializeFullClusterState(newState, destination.getVersion());
-                    serializedStates.put(destination.getVersion(), bytes);
+                    bytes = serializedStates.computeIfAbsent(
+                        destination.getVersion(),
+                        v -> serializeFullClusterState(newState, destination));
                 } catch (Exception e) {
                     logger.warn(() -> new ParameterizedMessage(
                         "failed to serialize cluster state before publishing it to node {}", destination), e);
@@ -352,58 +365,71 @@ public class PublicationTransportHandler {
                     return;
                 }
             }
-            sendClusterState(destination, bytes, false, listener);
+            sendClusterState(destination, bytes, listener);
         }
 
         private void sendClusterStateDiff(DiscoveryNode destination, ActionListener<PublishWithJoinResponse> listener) {
-            final BytesReference bytes = serializedDiffs.get(destination.getVersion());
+            final ReleasableBytesReference bytes = serializedDiffs.get(destination.getVersion());
             assert bytes != null
                 : "failed to find serialized diff for node " + destination + " of version [" + destination.getVersion() + "]";
-            sendClusterState(destination, bytes, true, listener);
-        }
 
-        private void sendClusterState(DiscoveryNode destination, BytesReference bytes, boolean retryWithFullClusterStateOnFailure,
-                                      ActionListener<PublishWithJoinResponse> listener) {
-            try {
-                final BytesTransportRequest request = new BytesTransportRequest(bytes, destination.getVersion());
-                final Consumer<TransportException> transportExceptionHandler = exp -> {
-                    if (retryWithFullClusterStateOnFailure && exp.unwrapCause() instanceof IncompatibleClusterStateVersionException) {
-                        logger.debug("resending full cluster state to node {} reason {}", destination, exp.getDetailedMessage());
-                        sendFullClusterState(destination, listener);
-                    } else {
-                        logger.debug(() -> new ParameterizedMessage("failed to send cluster state to {}", destination), exp);
-                        listener.onFailure(exp);
+            // acquire a ref to the context just in case we need to try again with the full cluster state
+            if (tryIncRef() == false) {
+                assert false;
+                listener.onFailure(new IllegalStateException("publication context released before transmission"));
+                return;
+            }
+            sendClusterState(destination, bytes, ActionListener.runAfter(listener.delegateResponse((delegate, e) -> {
+                if (e instanceof TransportException) {
+                    final TransportException transportException = (TransportException) e;
+                    if (transportException.unwrapCause() instanceof IncompatibleClusterStateVersionException) {
+                        logger.debug(() -> new ParameterizedMessage(
+                            "resending full cluster state to node {} reason {}",
+                            destination,
+                            transportException.getDetailedMessage()));
+                        sendFullClusterState(destination, delegate);
+                        return;
                     }
-                };
-                final TransportResponseHandler<PublishWithJoinResponse> responseHandler =
-                    new TransportResponseHandler<PublishWithJoinResponse>() {
-
-                        @Override
-                        public PublishWithJoinResponse read(StreamInput in) throws IOException {
-                            return new PublishWithJoinResponse(in);
-                        }
-
-                        @Override
-                        public void handleResponse(PublishWithJoinResponse response) {
-                            listener.onResponse(response);
-                        }
+                }
 
-                        @Override
-                        public void handleException(TransportException exp) {
-                            transportExceptionHandler.accept(exp);
-                        }
+                logger.debug(new ParameterizedMessage("failed to send cluster state to {}", destination), e);
+                delegate.onFailure(e);
+            }), this::decRef));
+        }
 
-                        @Override
-                        public String executor() {
-                            return ThreadPool.Names.GENERIC;
-                        }
-                    };
-                transportService.sendRequest(destination, PUBLISH_STATE_ACTION_NAME, request, STATE_REQUEST_OPTIONS, responseHandler);
+        private void sendClusterState(
+            DiscoveryNode destination,
+            ReleasableBytesReference bytes,
+            ActionListener<PublishWithJoinResponse> listener
+        ) {
+            assert refCount() > 0;
+            if (bytes.tryIncRef() == false) {
+                assert false;
+                listener.onFailure(new IllegalStateException("serialized cluster state released before transmission"));
+                return;
+            }
+            try {
+                transportService.sendRequest(
+                    destination,
+                    PUBLISH_STATE_ACTION_NAME,
+                    new BytesTransportRequest(bytes, destination.getVersion()),
+                    STATE_REQUEST_OPTIONS,
+                    new ActionListenerResponseHandler<PublishWithJoinResponse>(
+                        ActionListener.runAfter(listener, bytes::decRef),
+                        PublishWithJoinResponse::new,
+                        ThreadPool.Names.GENERIC));
             } catch (Exception e) {
+                assert false : e;
                 logger.warn(() -> new ParameterizedMessage("error sending cluster state to {}", destination), e);
                 listener.onFailure(e);
             }
         }
+
+        @Override
+        protected void closeInternal() {
+            serializedDiffs.values().forEach(Releasables::closeExpectNoException);
+            serializedStates.values().forEach(Releasables::closeExpectNoException);
+        }
     }
 
 }

+ 1 - 0
server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java

@@ -49,6 +49,7 @@ public final class ReleasableBytesReference implements RefCounted, Releasable, B
     }
 
     public static ReleasableBytesReference wrap(BytesReference reference) {
+        assert reference instanceof ReleasableBytesReference == false : "use #retain() instead of #wrap() on a " + reference.getClass();
         return reference.length() == 0 ? empty() : new ReleasableBytesReference(reference, NO_OP);
     }
 

+ 34 - 10
server/src/main/java/org/elasticsearch/discovery/DiscoveryModule.java

@@ -26,6 +26,7 @@ import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Setting.Property;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
+import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.gateway.GatewayMetaState;
 import org.elasticsearch.monitor.NodeHealthService;
 import org.elasticsearch.plugins.DiscoveryPlugin;
@@ -71,11 +72,22 @@ public class DiscoveryModule {
 
     private final Discovery discovery;
 
-    public DiscoveryModule(Settings settings, TransportService transportService,
-                           NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService, MasterService masterService,
-                           ClusterApplier clusterApplier, ClusterSettings clusterSettings, List<DiscoveryPlugin> plugins,
-                           AllocationService allocationService, Path configFile, GatewayMetaState gatewayMetaState,
-                           RerouteService rerouteService, NodeHealthService nodeHealthService) {
+    public DiscoveryModule(
+        Settings settings,
+        BigArrays bigArrays,
+        TransportService transportService,
+        NamedWriteableRegistry namedWriteableRegistry,
+        NetworkService networkService,
+        MasterService masterService,
+        ClusterApplier clusterApplier,
+        ClusterSettings clusterSettings,
+        List<DiscoveryPlugin> plugins,
+        AllocationService allocationService,
+        Path configFile,
+        GatewayMetaState gatewayMetaState,
+        RerouteService rerouteService,
+        NodeHealthService nodeHealthService
+    ) {
         final Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators = new ArrayList<>();
         final Map<String, Supplier<SeedHostsProvider>> hostProviders = new HashMap<>();
         hostProviders.put("settings", () -> new SettingsBasedSeedHostsProvider(settings, transportService));
@@ -133,11 +145,23 @@ public class DiscoveryModule {
         }
 
         if (ZEN2_DISCOVERY_TYPE.equals(discoveryType) || SINGLE_NODE_DISCOVERY_TYPE.equals(discoveryType)) {
-            discovery = new Coordinator(NODE_NAME_SETTING.get(settings),
-                settings, clusterSettings,
-                transportService, namedWriteableRegistry, allocationService, masterService, gatewayMetaState::getPersistedState,
-                seedHostsProvider, clusterApplier, joinValidators, new Random(Randomness.get().nextLong()), rerouteService,
-                electionStrategy, nodeHealthService);
+            discovery = new Coordinator(
+                NODE_NAME_SETTING.get(settings),
+                settings,
+                clusterSettings,
+                bigArrays,
+                transportService,
+                namedWriteableRegistry,
+                allocationService,
+                masterService,
+                gatewayMetaState::getPersistedState,
+                seedHostsProvider,
+                clusterApplier,
+                joinValidators,
+                new Random(Randomness.get().nextLong()),
+                rerouteService,
+                electionStrategy,
+                nodeHealthService);
         } else {
             throw new IllegalArgumentException("Unknown discovery type [" + discoveryType + "]");
         }

+ 16 - 5
server/src/main/java/org/elasticsearch/node/Node.java

@@ -611,11 +611,22 @@ public class Node implements Closeable {
                 clusterService.getClusterSettings(), client, threadPool::relativeTimeInMillis, rerouteService);
             clusterInfoService.addListener(diskThresholdMonitor::onNewInfo);
 
-            final DiscoveryModule discoveryModule = new DiscoveryModule(settings, transportService, namedWriteableRegistry,
-                networkService, clusterService.getMasterService(), clusterService.getClusterApplierService(),
-                clusterService.getClusterSettings(), pluginsService.filterPlugins(DiscoveryPlugin.class),
-                clusterModule.getAllocationService(), environment.configFile(), gatewayMetaState, rerouteService,
-                fsHealthService);
+            final DiscoveryModule discoveryModule = new DiscoveryModule(
+                settings,
+                bigArrays,
+                transportService,
+                namedWriteableRegistry,
+                networkService,
+                clusterService.getMasterService(),
+                clusterService.getClusterApplierService(),
+                clusterService.getClusterSettings(),
+                pluginsService.filterPlugins(DiscoveryPlugin.class),
+                clusterModule.getAllocationService(),
+                environment.configFile(),
+                gatewayMetaState,
+                rerouteService,
+                fsHealthService
+            );
             this.nodeService = new NodeService(settings, threadPool, monitorService, discoveryModule.getDiscovery(),
                 transportService, indicesService, pluginsService, circuitBreakerService, scriptService,
                 httpServerTransport, ingestService, clusterService, settingsModule.getSettingsFilter(), responseCollectorService,

+ 2 - 2
server/src/main/java/org/elasticsearch/transport/BytesTransportRequest.java

@@ -32,8 +32,8 @@ public class BytesTransportRequest extends TransportRequest implements RefCounte
         version = in.getVersion();
     }
 
-    public BytesTransportRequest(BytesReference bytes, Version version) {
-        this.bytes = ReleasableBytesReference.wrap(bytes);
+    public BytesTransportRequest(ReleasableBytesReference bytes, Version version) { // callers must now supply a ReleasableBytesReference directly
+        this.bytes = bytes; // stored as passed: this constructor neither wraps nor retains the reference
         this.version = version;
     }
 

+ 12 - 0
server/src/test/java/org/elasticsearch/cluster/coordination/CoordinatorTests.java

@@ -786,6 +786,8 @@ public class CoordinatorTests extends AbstractCoordinatorTestCase {
             cluster.stabilise(defaultMillis(PUBLISH_TIMEOUT_SETTING));
             assertTrue("expected eventual ack from " + leader, ackCollector.hasAckedSuccessfully(leader));
             assertFalse("expected no ack from " + follower0, ackCollector.hasAcked(follower0));
+
+            follower0.setClusterStateApplyResponse(ClusterStateApplyResponse.SUCCEED);
         }
     }
 
@@ -1388,6 +1390,10 @@ public class CoordinatorTests extends AbstractCoordinatorTestCase {
             cluster.bootstrapIfNecessary();
             cluster.runFor(10000, "failing join validation");
             assertTrue(cluster.clusterNodes.stream().allMatch(cn -> cn.getLastAppliedClusterState().version() == 0));
+
+            for (ClusterNode clusterNode : cluster.clusterNodes) {
+                clusterNode.extraJoinValidators.clear();
+            }
         }
     }
 
@@ -1565,6 +1571,8 @@ public class CoordinatorTests extends AbstractCoordinatorTestCase {
                 + 7 * delayVariabilityMillis, "stabilising");
 
             assertThat(cluster.getAnyLeader(), sameInstance(clusterNode));
+
+            cluster.deterministicTaskQueue.setExecutionDelayVariabilityMillis(DEFAULT_DELAY_VARIABILITY);
         }
     }
 
@@ -1705,6 +1713,10 @@ public class CoordinatorTests extends AbstractCoordinatorTestCase {
                     mockLogAppender.stop();
                 }
             }
+
+            for (ClusterNode clusterNode : cluster.clusterNodes) {
+                clusterNode.heal();
+            }
         }
     }
 

+ 16 - 4
server/src/test/java/org/elasticsearch/cluster/coordination/NodeJoinTests.java

@@ -26,9 +26,12 @@ import org.elasticsearch.cluster.service.MasterServiceTests;
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.common.util.MockPageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.BaseFuture;
 import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue;
 import org.elasticsearch.common.util.concurrent.FutureUtils;
+import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
 import org.elasticsearch.monitor.NodeHealthService;
 import org.elasticsearch.monitor.StatusInfo;
 import org.elasticsearch.node.Node;
@@ -171,14 +174,23 @@ public class NodeJoinTests extends ESTestCase {
             TransportService.NOOP_TRANSPORT_INTERCEPTOR,
             x -> initialState.nodes().getLocalNode(),
             clusterSettings, Collections.emptySet());
-        coordinator = new Coordinator("test_node", Settings.EMPTY, clusterSettings,
-            transportService, writableRegistry(),
+        coordinator = new Coordinator(
+            "test_node",
+            Settings.EMPTY,
+            clusterSettings,
+            new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()),
+            transportService,
+            writableRegistry(),
             ESAllocationTestCase.createAllocationService(Settings.EMPTY),
             masterService,
-            () -> new InMemoryPersistedState(term, initialState), r -> emptyList(),
+            () -> new InMemoryPersistedState(term, initialState),
+            r -> emptyList(),
             new NoOpClusterApplier(),
             Collections.emptyList(),
-            random, (s, p, r) -> {}, ElectionStrategy.DEFAULT_INSTANCE, nodeHealthService);
+            random,
+            (s, p, r) -> {},
+            ElectionStrategy.DEFAULT_INSTANCE,
+            nodeHealthService);
         transportService.start();
         transportService.acceptIncomingRequests();
         transport = capturingTransport;

+ 240 - 20
server/src/test/java/org/elasticsearch/cluster/coordination/PublicationTransportHandlerTests.java

@@ -8,50 +8,73 @@
 package org.elasticsearch.cluster.coordination;
 
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.Version;
-import org.elasticsearch.cluster.ClusterStatePublicationEvent;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.ClusterStatePublicationEvent;
 import org.elasticsearch.cluster.Diff;
+import org.elasticsearch.cluster.IncompatibleClusterStateVersionException;
 import org.elasticsearch.cluster.coordination.CoordinationMetadata.VotingConfiguration;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.compress.Compressor;
+import org.elasticsearch.common.compress.CompressorFactory;
+import org.elasticsearch.common.io.stream.InputStreamStreamInput;
+import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.internal.io.IOUtils;
+import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
 import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.test.transport.CapturingTransport;
+import org.elasticsearch.test.VersionUtils;
+import org.elasticsearch.test.transport.MockTransport;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.BytesTransportRequest;
+import org.elasticsearch.transport.RemoteTransportException;
+import org.elasticsearch.transport.TransportRequest;
 import org.elasticsearch.transport.TransportService;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
 
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.Mockito.mock;
 
 public class PublicationTransportHandlerTests extends ESTestCase {
 
     public void testDiffSerializationFailure() {
-        DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue();
-        final ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
         final DiscoveryNode localNode = new DiscoveryNode("localNode", buildNewFakeTransportAddress(), Version.CURRENT);
-        final TransportService transportService = new CapturingTransport().createTransportService(Settings.EMPTY,
-            deterministicTaskQueue.getThreadPool(),
-            TransportService.NOOP_TRANSPORT_INTERCEPTOR,
-            x -> localNode,
-            clusterSettings, Collections.emptySet());
-        final PublicationTransportHandler handler = new PublicationTransportHandler(transportService,
-            writableRegistry(), pu -> null, (pu, l) -> {});
-        transportService.start();
-        transportService.acceptIncomingRequests();
+        final PublicationTransportHandler handler = new PublicationTransportHandler(
+            new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService()),
+            mock(TransportService.class),
+            writableRegistry(),
+            pu -> null,
+            (pu, l) -> {});
 
         final DiscoveryNode otherNode = new DiscoveryNode("otherNode", buildNewFakeTransportAddress(), Version.CURRENT);
-        final ClusterState clusterState = CoordinationStateTests.clusterState(2L, 1L,
+        final ClusterState clusterState = CoordinationStateTests.clusterState(
+            2L,
+            1L,
             DiscoveryNodes.builder().add(localNode).add(otherNode).localNodeId(localNode.getId()).build(),
-            VotingConfiguration.EMPTY_CONFIG, VotingConfiguration.EMPTY_CONFIG, 0L);
+            VotingConfiguration.EMPTY_CONFIG,
+            VotingConfiguration.EMPTY_CONFIG,
+            0L);
 
-        final ClusterState unserializableClusterState = new ClusterState(clusterState.version(),
-            clusterState.stateUUID(), clusterState) {
+        final ClusterState unserializableClusterState = new ClusterState(clusterState.version(), clusterState.stateUUID(), clusterState) {
             @Override
             public Diff<ClusterState> diff(ClusterState previousState) {
                 return new Diff<ClusterState>() {
@@ -63,16 +86,213 @@ public class PublicationTransportHandlerTests extends ESTestCase {
 
                     @Override
                     public void writeTo(StreamOutput out) throws IOException {
+                        out.writeString("allocate something to detect leaks");
                         throw new IOException("Simulated failure of diff serialization");
                     }
                 };
             }
         };
 
-        ElasticsearchException e = expectThrows(ElasticsearchException.class, () ->
-            handler.newPublicationContext(new ClusterStatePublicationEvent("test", clusterState, unserializableClusterState, 0L, 0L)));
+        final ElasticsearchException e = expectThrows(
+            ElasticsearchException.class,
+            () -> handler.newPublicationContext(new ClusterStatePublicationEvent(
+                "test",
+                clusterState,
+                unserializableClusterState,
+                0L,
+                0L)));
         assertNotNull(e.getCause());
         assertThat(e.getCause(), instanceOf(IOException.class));
         assertThat(e.getCause().getMessage(), containsString("Simulated failure of diff serialization"));
     }
+
+    private static boolean isDiff(BytesTransportRequest request, DiscoveryNode node) { // test helper: true when the first boolean in the request payload is false
+        try {
+            StreamInput in = null;
+            try {
+                in = request.bytes().streamInput();
+                final Compressor compressor = CompressorFactory.compressor(request.bytes());
+                if (compressor != null) { // a compressor was detected for these bytes, so read through its input stream instead
+                    in = new InputStreamStreamInput(compressor.threadLocalInputStream(in));
+                }
+                in.setVersion(node.getVersion()); // read using the given node's version
+                return in.readBoolean() == false; // NOTE(review): presumably the sender writes true for a full state and false for a diff — confirm against the serialization code
+            } finally {
+                IOUtils.close(in);
+            }
+        } catch (IOException e) {
+            throw new AssertionError("unexpected", e); // an I/O failure while inspecting the payload is surfaced as an AssertionError
+        }
+    }
+
+    public void testSerializationFailuresDoNotLeak() throws InterruptedException { // sends a publish request to every node of nextClusterState while randomly injecting serialization and transport failures; NOTE(review): leak detection presumably comes from MockBigArrays — confirm
+        final ThreadPool threadPool = new TestThreadPool("test");
+        try {
+            threadPool.getThreadContext().markAsSystemContext();
+
+            final boolean simulateFailures = randomBoolean(); // when false, only the IncompatibleClusterStateVersionException response below can still be injected
+            final DiscoveryNode localNode = new DiscoveryNode("localNode", buildNewFakeTransportAddress(), Version.CURRENT);
+            final MockTransport mockTransport = new MockTransport() {
+
+                @Nullable
+                private Exception simulateException(String action, BytesTransportRequest request, DiscoveryNode node) { // returns the exception to respond with, or null to respond successfully
+                    if (action.equals(PublicationTransportHandler.PUBLISH_STATE_ACTION_NAME) && rarely()) {
+                        if (isDiff(request, node) && randomBoolean()) { // this branch does not depend on simulateFailures
+                            return new IncompatibleClusterStateVersionException(
+                                randomNonNegativeLong(),
+                                UUIDs.randomBase64UUID(random()),
+                                randomNonNegativeLong(),
+                                UUIDs.randomBase64UUID(random()));
+                        }
+
+                        if (simulateFailures && randomBoolean()) {
+                            return new IOException("simulated failure");
+                        }
+                    }
+
+                    return null;
+                }
+
+                @Override
+                protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode node) { // answers every request immediately, with either a response or a wrapped simulated error
+                    final Exception exception = simulateException(action, (BytesTransportRequest) request, node);
+                    if (exception == null) {
+                        handleResponse(
+                            requestId,
+                            new PublishWithJoinResponse(
+                                new PublishResponse(randomNonNegativeLong(), randomNonNegativeLong()),
+                                Optional.empty()));
+                    } else {
+                        handleError(requestId, new RemoteTransportException(
+                            node.getName(),
+                            node.getAddress(),
+                            action,
+                            exception));
+                    }
+                }
+            };
+            final TransportService transportService = mockTransport.createTransportService(
+                Settings.EMPTY,
+                threadPool,
+                TransportService.NOOP_TRANSPORT_INTERCEPTOR,
+                x -> localNode,
+                new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
+                Collections.emptySet());
+            final PublicationTransportHandler handler = new PublicationTransportHandler(
+                new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService()),
+                transportService,
+                writableRegistry(),
+                pu -> null,
+                (pu, l) -> {
+                });
+            transportService.start();
+            transportService.acceptIncomingRequests();
+
+            final List<DiscoveryNode> allNodes = new ArrayList<>();
+            while (allNodes.size() < 10) { // ten candidate nodes, each with a random version compatible with Version.CURRENT
+                allNodes.add(new DiscoveryNode(
+                    "node-" + allNodes.size(),
+                    buildNewFakeTransportAddress(),
+                    VersionUtils.randomCompatibleVersion(random(), Version.CURRENT)));
+            }
+
+            final DiscoveryNodes.Builder prevNodes = DiscoveryNodes.builder();
+            prevNodes.add(localNode);
+            prevNodes.localNodeId(localNode.getId());
+            randomSubsetOf(allNodes).forEach(prevNodes::add); // previous state: local node plus a random subset of the candidates
+
+            final DiscoveryNodes.Builder nextNodes = DiscoveryNodes.builder();
+            nextNodes.add(localNode);
+            nextNodes.localNodeId(localNode.getId());
+            randomSubsetOf(allNodes).forEach(nextNodes::add); // next state: local node plus an independently chosen random subset
+
+            final ClusterState prevClusterState = CoordinationStateTests.clusterState(
+                randomLongBetween(1L, Long.MAX_VALUE - 1),
+                randomNonNegativeLong(),
+                prevNodes.build(),
+                VotingConfiguration.EMPTY_CONFIG,
+                VotingConfiguration.EMPTY_CONFIG,
+                0L);
+
+            final ClusterState nextClusterState = new ClusterState(
+                randomNonNegativeLong(),
+                UUIDs.randomBase64UUID(random()),
+                CoordinationStateTests.clusterState(
+                    randomLongBetween(prevClusterState.term() + 1, Long.MAX_VALUE),
+                    randomNonNegativeLong(),
+                    nextNodes.build(),
+                    VotingConfiguration.EMPTY_CONFIG,
+                    VotingConfiguration.EMPTY_CONFIG,
+                    0L)) {
+
+                @Override
+                public void writeTo(StreamOutput out) throws IOException {
+                    if (simulateFailures && rarely()) {
+                        out.writeString("allocate something to detect leaks"); // write some bytes before failing so the output has content when the exception is thrown
+                        throw new IOException("simulated failure");
+                    } else {
+                        super.writeTo(out);
+                    }
+                }
+
+                @Override
+                public Diff<ClusterState> diff(ClusterState previousState) {
+                    if (simulateFailures && rarely()) { // occasionally substitute a diff whose serialization always fails
+                        return new Diff<ClusterState>() {
+                            @Override
+                            public ClusterState apply(ClusterState part) {
+                                fail("this diff shouldn't be applied");
+                                return part;
+                            }
+
+                            @Override
+                            public void writeTo(StreamOutput out) throws IOException {
+                                out.writeString("allocate something to detect leaks");
+                                throw new IOException("simulated failure");
+                            }
+                        };
+                    } else {
+                        return super.diff(previousState);
+                    }
+                }
+            };
+
+            final PublicationTransportHandler.PublicationContext context;
+            try {
+                context = handler.newPublicationContext(
+                    new ClusterStatePublicationEvent("test", prevClusterState, nextClusterState, 0L, 0L));
+            } catch (ElasticsearchException e) {
+                assertTrue(simulateFailures); // context creation may only fail when failures are being simulated
+                assertThat(e.getCause(), instanceOf(IOException.class));
+                assertThat(e.getCause().getMessage(), equalTo("simulated failure"));
+                return; // no context was created, so there is nothing to send
+            }
+
+            final CountDownLatch requestsLatch = new CountDownLatch(nextClusterState.nodes().getSize()); // counts sendPublishRequest calls that have returned
+            final CountDownLatch responsesLatch = new CountDownLatch(nextClusterState.nodes().getSize()); // counts listeners that have completed
+
+            for (DiscoveryNode discoveryNode : nextClusterState.nodes()) {
+                threadPool.generic().execute(() -> {
+                    context.sendPublishRequest(
+                        discoveryNode,
+                        new PublishRequest(nextClusterState),
+                        ActionListener.runAfter(ActionListener.wrap(r -> {
+                        }, e -> {
+                            assert simulateFailures;
+                            final Throwable inner = ExceptionsHelper.unwrap(e, IOException.class);
+                            assert inner instanceof IOException : e;
+                            assertThat(inner.getMessage(), equalTo("simulated failure"));
+                        }), responsesLatch::countDown));
+                    requestsLatch.countDown();
+                });
+            }
+
+            assertTrue(requestsLatch.await(10, TimeUnit.SECONDS));
+            context.decRef(); // drop this test's reference only after every sendPublishRequest call has returned
+            assertTrue(responsesLatch.await(10, TimeUnit.SECONDS)); // then wait for every listener to complete
+        } finally {
+            assertTrue(ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS));
+        }
+    }
+
 }

+ 17 - 3
server/src/test/java/org/elasticsearch/discovery/DiscoveryModuleTests.java

@@ -18,6 +18,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.network.NetworkService;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.core.internal.io.IOUtils;
 import org.elasticsearch.gateway.GatewayMetaState;
 import org.elasticsearch.plugins.DiscoveryPlugin;
@@ -74,9 +75,22 @@ public class DiscoveryModuleTests extends ESTestCase {
     }
 
     private DiscoveryModule newModule(Settings settings, List<DiscoveryPlugin> plugins) { // builds a DiscoveryModule from the test fixture's fields, with null for the collaborators labelled below
-        return new DiscoveryModule(settings, transportService, namedWriteableRegistry, null, masterService,
-            clusterApplier, clusterSettings, plugins, null, createTempDir().toAbsolutePath(), gatewayMetaState,
-            mock(RerouteService.class), null);
+        return new DiscoveryModule(
+            settings,
+            BigArrays.NON_RECYCLING_INSTANCE, // bigArrays
+            transportService,
+            namedWriteableRegistry,
+            null, // networkService
+            masterService,
+            clusterApplier,
+            clusterSettings,
+            plugins,
+            null, // allocationService
+            createTempDir().toAbsolutePath(), // configFile
+            gatewayMetaState,
+            mock(RerouteService.class),
+            null // nodeHealthService
+        );
     }
 
     public void testDefaults() {

+ 1 - 0
server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java

@@ -2149,6 +2149,7 @@ public class SnapshotResiliencyTests extends ESTestCase {
                     node.getName(),
                     clusterService.getSettings(),
                     clusterService.getClusterSettings(),
+                    bigArrays,
                     transportService,
                     namedWriteableRegistry,
                     allocationService,

+ 163 - 13
test/framework/src/main/java/org/elasticsearch/cluster/coordination/AbstractCoordinatorTestCase.java

@@ -11,6 +11,7 @@ import com.carrotsearch.randomizedtesting.RandomizedContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.ClusterModule;
@@ -41,6 +42,7 @@ import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.ByteArray;
 import org.elasticsearch.common.util.MockBigArrays;
 import org.elasticsearch.common.util.MockPageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue;
@@ -454,6 +456,10 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
                         final ClusterNode clusterNode = getAnyNode();
                         logger.debug("----> [runRandomly {}] applying initial configuration on {}", step, clusterNode.getId());
                         clusterNode.applyInitialConfiguration();
+                    } else if (rarely()) {
+                        final ClusterNode clusterNode = getAnyNode();
+                        logger.debug("----> [runRandomly {}] completing blackholed requests sent by {}", step, clusterNode.getId());
+                        clusterNode.deliverBlackholedRequests();
                     } else {
                         if (deterministicTaskQueue.hasDeferredTasks() && randomBoolean()) {
                             deterministicTaskQueue.advanceTime();
@@ -468,6 +474,8 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
                 assertConsistentStates();
             }
 
+            logger.debug("delivering pending blackholed requests");
+            clusterNodes.forEach(ClusterNode::deliverBlackholedRequests);
             logger.debug("running {} cleanup actions", cleanupActions.size());
             cleanupActions.forEach(Runnable::run);
             logger.debug("finished running cleanup actions");
@@ -487,8 +495,9 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
                 if (storedState == null) {
                     committedStatesByVersion.put(applierState.getVersion(), applierState);
                 } else {
-                    assertEquals("expected " + applierState + " but got " + storedState,
-                        value(applierState), value(storedState));
+                    if (value(applierState) != value(storedState)) {
+                        fail("expected " + applierState + " but got " + storedState);
+                    }
                 }
             }
         }
@@ -728,6 +737,12 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
 
         @Override
         public void close() { // delivers any outstanding blackholed requests, re-stabilising after each pass, and then closes every node
+            // noinspection ReplaceInefficientStreamCount using .count() to run the filter on every node
+            while (clusterNodes.stream().filter(ClusterNode::deliverBlackholedRequests).count() != 0L) { // repeat until a full pass matches no node; NOTE(review): presumably the predicate is true when a node delivered something — confirm
+                logger.debug("--> stabilising again after delivering blackholed requests");
+                stabilise(DEFAULT_STABILISATION_TIME);
+            }
+
             clusterNodes.forEach(ClusterNode::close);
         }
 
@@ -908,7 +923,7 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
             private DisruptableMockTransport mockTransport;
             private NodeHealthService nodeHealthService;
             List<BiConsumer<DiscoveryNode, ClusterState>> extraJoinValidators = new ArrayList<>();
-
+            private DelegatingBigArrays delegatingBigArrays;
 
             ClusterNode(int nodeIndex, boolean masterEligible, Settings nodeSettings, NodeHealthService nodeHealthService) {
                 this(nodeIndex, createDiscoveryNode(nodeIndex, masterEligible), defaultPersistedStateSupplier, nodeSettings,
@@ -970,10 +985,24 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
                 final Collection<BiConsumer<DiscoveryNode, ClusterState>> onJoinValidators =
                     Collections.singletonList((dn, cs) -> extraJoinValidators.forEach(validator -> validator.accept(dn, cs)));
                 final AllocationService allocationService = ESAllocationTestCase.createAllocationService(Settings.EMPTY);
-                coordinator = new Coordinator("test_node", settings, clusterSettings, transportService, getNamedWriteableRegistry(),
-                    allocationService, masterService, this::getPersistedState,
-                    Cluster.this::provideSeedHosts, clusterApplierService, onJoinValidators, Randomness.get(), (s, p, r) -> {},
-                    getElectionStrategy(), nodeHealthService);
+                delegatingBigArrays = new DelegatingBigArrays(bigArrays);
+                coordinator = new Coordinator(
+                    "test_node",
+                    settings,
+                    clusterSettings,
+                    delegatingBigArrays,
+                    transportService,
+                    getNamedWriteableRegistry(),
+                    allocationService,
+                    masterService,
+                    this::getPersistedState,
+                    Cluster.this::provideSeedHosts,
+                    clusterApplierService,
+                    onJoinValidators,
+                    Randomness.get(),
+                    (s, p, r) -> {},
+                    getElectionStrategy(),
+                    nodeHealthService);
                 masterService.setClusterStatePublisher(coordinator);
                 final GatewayService gatewayService
                     = new GatewayService(settings, allocationService, clusterService, threadPool, coordinator, null);
@@ -1016,9 +1045,16 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
                     address.address().getHostString(), address.getAddress(), address, Collections.emptyMap(),
                     localNode.isMasterNode() && DiscoveryNode.isMasterNode(nodeSettings)
                         ? allExceptVotingOnlyRole : emptySet(), Version.CURRENT);
-                return new ClusterNode(nodeIndex, newLocalNode,
-                    node -> new MockPersistedState(newLocalNode, persistedState, adaptGlobalMetadata, adaptCurrentTerm), nodeSettings,
-                    nodeHealthService);
+                try {
+                    return new ClusterNode(
+                        nodeIndex,
+                        newLocalNode,
+                        node -> new MockPersistedState(newLocalNode, persistedState, adaptGlobalMetadata, adaptCurrentTerm),
+                        nodeSettings,
+                        nodeHealthService);
+                } finally {
+                    delegatingBigArrays.releaseAll();
+                }
             }
 
             private CoordinationState.PersistedState getPersistedState() {
@@ -1060,11 +1096,17 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
                 return new Runnable() {
                     @Override
                     public void run() {
-                        if (clusterNodes.contains(ClusterNode.this) == false) {
+                        if (clusterNodes.contains(ClusterNode.this)) {
+                            wrapped.run();
+                        } else if (runnable instanceof DisruptableMockTransport.RebootSensitiveRunnable) {
+                            logger.trace(
+                                "completing reboot-sensitive runnable {} from node {} as node has been removed from cluster",
+                                runnable,
+                                localNode);
+                            ((DisruptableMockTransport.RebootSensitiveRunnable) runnable).ifRebooted();
+                        } else {
                             logger.trace("ignoring runnable {} from node {} as node has been removed from cluster", runnable, localNode);
-                            return;
                         }
-                        wrapped.run();
                     }
 
                     @Override
@@ -1236,6 +1278,10 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
             void allowClusterStateApplicationFailure() {
                 clusterApplierService.allowClusterStateApplicationFailure();
             }
+
+            /**
+             * Forwards to {@code mockTransport.deliverBlackholedRequests()} for this node.
+             *
+             * @return whatever the mock transport's {@code deliverBlackholedRequests()} call returned
+             */
+            boolean deliverBlackholedRequests() {
+                final boolean anyDelivered = mockTransport.deliverBlackholedRequests();
+                return anyDelivered;
+            }
         }
 
         private List<TransportAddress> provideSeedHosts(SeedHostsProvider.HostsResolver ignored) {
@@ -1490,4 +1536,108 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
         assertThat(spec.nextState(7, null, 42), equalTo(Optional.empty()));
     }
 
+    /**
+     * A {@link BigArrays} which forwards {@link ByteArray} allocations to another instance and keeps track of every array it has
+     * handed out that has not been closed yet, so that {@link #releaseAll()} can close them all, e.g. if the node reboots. Only
+     * {@link ByteArray} allocation and resizing are overridden since that's all the {@link Coordinator} needs.
+     */
+    static class DelegatingBigArrays extends BigArrays {
+
+        private final BigArrays delegate;
+
+        // wrappers handed out by this instance which have neither been closed nor superseded by a resize
+        private final Set<DelegatingByteArray> trackedArrays = new HashSet<>();
+
+        DelegatingBigArrays(BigArrays delegate) {
+            // the superclass gets no real collaborators; only the ByteArray methods overridden below are expected to be used
+            super(null, null, null);
+            this.delegate = delegate;
+        }
+
+        @Override
+        public ByteArray newByteArray(long size, boolean clearOnResize) {
+            final ByteArray allocated = delegate.newByteArray(size, clearOnResize);
+            return track(allocated);
+        }
+
+        @Override
+        public ByteArray resize(ByteArray array, long size) {
+            assert array instanceof DelegatingByteArray;
+            // the old wrapper is replaced by a new one around the resized array, so stop tracking it first
+            trackedArrays.remove(array);
+            final ByteArray resized = delegate.resize(((DelegatingByteArray) array).unwrap(), size);
+            return track(resized);
+        }
+
+        /**
+         * Wraps the given array and starts tracking the wrapper.
+         */
+        private ByteArray track(ByteArray byteArray) {
+            final DelegatingByteArray tracked = new DelegatingByteArray(byteArray);
+            trackedArrays.add(tracked);
+            return tracked;
+        }
+
+        /**
+         * Closes every array that is still tracked.
+         */
+        void releaseAll() {
+            // iterate over a snapshot because closing a wrapper removes it from trackedArrays
+            List.copyOf(trackedArrays).forEach(DelegatingByteArray::close);
+            assert trackedArrays.isEmpty() : trackedArrays;
+        }
+
+        /**
+         * A {@link ByteArray} which passes every call through to the wrapped array and stops being tracked once closed.
+         */
+        private class DelegatingByteArray implements ByteArray {
+
+            private final ByteArray inner;
+
+            DelegatingByteArray(ByteArray inner) {
+                this.inner = inner;
+            }
+
+            ByteArray unwrap() {
+                return inner;
+            }
+
+            @Override
+            public void close() {
+                inner.close();
+                trackedArrays.remove(this);
+            }
+
+            @Override
+            public long size() {
+                return inner.size();
+            }
+
+            @Override
+            public long ramBytesUsed() {
+                return inner.ramBytesUsed();
+            }
+
+            @Override
+            public boolean hasArray() {
+                return inner.hasArray();
+            }
+
+            @Override
+            public byte[] array() {
+                return inner.array();
+            }
+
+            @Override
+            public byte get(long index) {
+                return inner.get(index);
+            }
+
+            @Override
+            public boolean get(long index, int len, BytesRef ref) {
+                return inner.get(index, len, ref);
+            }
+
+            @Override
+            public byte set(long index, byte value) {
+                return inner.set(index, value);
+            }
+
+            @Override
+            public void set(long index, byte[] buf, int offset, int len) {
+                inner.set(index, buf, offset, len);
+            }
+
+            @Override
+            public void fill(long fromIndex, long toIndex, byte value) {
+                inner.fill(fromIndex, toIndex, value);
+            }
+        }
+    }
 }

+ 60 - 7
test/framework/src/main/java/org/elasticsearch/test/disruption/DisruptableMockTransport.java

@@ -22,6 +22,7 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.CloseableConnection;
 import org.elasticsearch.transport.ConnectTransportException;
 import org.elasticsearch.transport.ConnectionProfile;
+import org.elasticsearch.transport.NodeNotConnectedException;
 import org.elasticsearch.transport.RequestHandlerRegistry;
 import org.elasticsearch.transport.TransportChannel;
 import org.elasticsearch.transport.TransportException;
@@ -32,9 +33,10 @@ import org.elasticsearch.transport.TransportResponse;
 import org.elasticsearch.transport.TransportService;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Optional;
 import java.util.Set;
-import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 
 import static org.elasticsearch.test.ESTestCase.copyWriteable;
@@ -43,6 +45,7 @@ public abstract class DisruptableMockTransport extends MockTransport {
     private final DiscoveryNode localNode;
     private final Logger logger;
     private final DeterministicTaskQueue deterministicTaskQueue;
+    private final List<Runnable> blackholedRequests = new ArrayList<>();
 
     public DisruptableMockTransport(DiscoveryNode localNode, Logger logger, DeterministicTaskQueue deterministicTaskQueue) {
         this.localNode = localNode;
@@ -101,7 +104,7 @@ public abstract class DisruptableMockTransport extends MockTransport {
         assert destinationTransport.getLocalNode().equals(getLocalNode()) == false :
             "non-local message from " + getLocalNode() + " to itself";
 
-        destinationTransport.execute(new Runnable() {
+        destinationTransport.execute(new RebootSensitiveRunnable() {
             @Override
             public void run() {
                 final ConnectionStatus connectionStatus = getConnectionStatus(destinationTransport.getLocalNode());
@@ -124,8 +127,29 @@ public abstract class DisruptableMockTransport extends MockTransport {
                 }
             }
 
+            /**
+             * Rather than dropping the request, schedules a task on the task queue which runs, via this transport's
+             * {@code execute}, an error response: a {@link NodeNotConnectedException} naming the destination node is passed to
+             * {@code handleRemoteError} for this request.
+             */
+            @Override
+            public void ifRebooted() {
+                final Runnable rebootErrorResponse = new Runnable() {
+                    @Override
+                    public void run() {
+                        final DiscoveryNode rebootedNode = destinationTransport.getLocalNode();
+                        handleRemoteError(requestId, new NodeNotConnectedException(rebootedNode, "node rebooted"));
+                    }
+
+                    @Override
+                    public String toString() {
+                        return "error response (reboot) to " + internalToString();
+                    }
+                };
+                deterministicTaskQueue.scheduleNow(() -> execute(rebootErrorResponse));
+            }
+
             @Override
             public String toString() {
+                return internalToString();
+            }
+
+            private String internalToString() {
                 return getRequestDescription(requestId, action, destinationTransport.getLocalNode());
             }
         });
@@ -146,15 +170,23 @@ public abstract class DisruptableMockTransport extends MockTransport {
     }
 
     protected String getRequestDescription(long requestId, String action, DiscoveryNode destination) {
-        return new ParameterizedMessage("[{}][{}] from {} to {}",
-            requestId, action, getLocalNode(), destination).getFormattedMessage();
+        return new ParameterizedMessage("[{}][{}] from {} to {}", requestId, action, getLocalNode(), destination).getFormattedMessage();
     }
 
     protected void onBlackholedDuringSend(long requestId, String action, DisruptableMockTransport destinationTransport) {
         logger.trace("dropping {}", getRequestDescription(requestId, action, destinationTransport.getLocalNode()));
-        // Delaying the request for one day and then disconnect to simulate a very long pause
-        deterministicTaskQueue.scheduleAt(deterministicTaskQueue.getCurrentTimeMillis() + TimeUnit.DAYS.toMillis(1L),
-                () -> onDisconnectedDuringSend(requestId, action, destinationTransport));
+        // Delaying the response until explicitly instructed, to simulate a very long delay
+        blackholedRequests.add(new Runnable() {
+            @Override
+            public void run() {
+                onDisconnectedDuringSend(requestId, action, destinationTransport);
+            }
+
+            @Override
+            public String toString() {
+                return "deferred handling of dropped " + getRequestDescription(requestId, action, destinationTransport.getLocalNode());
+            }
+        });
     }
 
     protected void onDisconnectedDuringSend(long requestId, String action, DisruptableMockTransport destinationTransport) {
@@ -262,6 +294,16 @@ public abstract class DisruptableMockTransport extends MockTransport {
         }
     }
 
+    /**
+     * Passes every request deferred by {@code onBlackholedDuringSend} to {@code deterministicTaskQueue.scheduleNow} and then
+     * forgets them.
+     *
+     * @return {@code true} if at least one deferred request was scheduled, {@code false} if there were none
+     */
+    public boolean deliverBlackholedRequests() {
+        final boolean anyDeferred = blackholedRequests.isEmpty() == false;
+        for (final Runnable deferredRequest : blackholedRequests) {
+            deterministicTaskQueue.scheduleNow(deferredRequest);
+        }
+        blackholedRequests.clear();
+        return anyDeferred;
+    }
+
     /**
      * Response type from {@link DisruptableMockTransport#getConnectionStatus(DiscoveryNode)} indicating whether, and how, messages should
      * be disrupted on this transport.
@@ -287,4 +329,15 @@ public abstract class DisruptableMockTransport extends MockTransport {
          */
         BLACK_HOLE_REQUESTS_ONLY
     }
+
+    /**
+     * A {@link Runnable} which additionally describes what to do when it cannot be run because its destination node rebooted.
+     * <p>
+     * When simulating sending requests to another node which might have rebooted, it's not realistic just to drop the action if the node
+     * reboots; instead we need to simulate the error response that comes back.
+     */
+    public interface RebootSensitiveRunnable extends Runnable {
+        /**
+         * Cleanup action to run if the destination node reboots. Callers are expected to invoke this instead of {@link #run()} in
+         * that case, so implementations should produce whatever error response the sender would otherwise be left waiting for.
+         */
+        void ifRebooted();
+    }
 }

+ 10 - 0
test/framework/src/test/java/org/elasticsearch/test/disruption/DisruptableMockTransportTests.java

@@ -56,6 +56,8 @@ public class DisruptableMockTransportTests extends ESTestCase {
 
     private DeterministicTaskQueue deterministicTaskQueue;
 
+    private Runnable deliverBlackholedRequests;
+
     private Set<Tuple<DiscoveryNode, DiscoveryNode>> disconnectedLinks;
     private Set<Tuple<DiscoveryNode, DiscoveryNode>> blackholedLinks;
     private Set<Tuple<DiscoveryNode, DiscoveryNode>> blackholedRequestLinks;
@@ -142,6 +144,8 @@ public class DisruptableMockTransportTests extends ESTestCase {
         deterministicTaskQueue.runAllTasksInTimeOrder();
         assertTrue(fut1.isDone());
         assertTrue(fut2.isDone());
+
+        deliverBlackholedRequests = () -> transports.forEach(DisruptableMockTransport::deliverBlackholedRequests);
     }
 
     private TransportRequestHandler<TransportRequest.Empty> requestHandlerShouldNotBeCalled() {
@@ -294,6 +298,9 @@ public class DisruptableMockTransportTests extends ESTestCase {
         disconnectedLinks.add(Tuple.tuple(node2, node1));
         responseHandlerChannel.get().sendResponse(TransportResponse.Empty.INSTANCE);
         deterministicTaskQueue.runAllTasks();
+        deliverBlackholedRequests.run();
+        deterministicTaskQueue.runAllTasks();
+
         assertThat(responseHandlerException.get(), instanceOf(ConnectTransportException.class));
     }
 
@@ -311,6 +318,9 @@ public class DisruptableMockTransportTests extends ESTestCase {
         disconnectedLinks.add(Tuple.tuple(node2, node1));
         responseHandlerChannel.get().sendResponse(new Exception());
         deterministicTaskQueue.runAllTasks();
+        deliverBlackholedRequests.run();
+        deterministicTaskQueue.runAllTasks();
+
         assertThat(responseHandlerException.get(), instanceOf(ConnectTransportException.class));
     }