Ver código fonte

Reduce resource needs of join validation (#85380)

Fixes a few scalability issues around join validation:

- compresses the cluster state sent over the wire
- shares the serialized cluster state across multiple nodes
- forks the decompression/deserialization work off the transport thread

Relates #77466
Closes #83204
David Turner 3 anos atrás
pai
commit
79f181d208

+ 6 - 0
docs/changelog/85380.yaml

@@ -0,0 +1,6 @@
+pr: 85380
+summary: Reduce resource needs of join validation
+area: Cluster Coordination
+type: enhancement
+issues:
+ - 83204

+ 9 - 0
docs/reference/modules/discovery/discovery-settings.asciidoc

@@ -201,6 +201,15 @@ Sets how long the master node waits for each cluster state update to be
 completely published to all nodes, unless `discovery.type` is set to
 `single-node`. The default value is `30s`. See <<cluster-state-publishing>>.
 
+`cluster.join_validation.cache_timeout`::
+(<<static-cluster-setting,Static>>)
+When a node requests to join the cluster, the elected master node sends it a
+copy of a recent cluster state to detect certain problems which might prevent
+the new node from joining the cluster. The master caches the state it sends and
+uses the cached state if another node joins the cluster soon after. This
+setting controls how long the master waits until it clears this cache. Defaults
+to `60s`.
+
 [[no-master-block]]
 `cluster.no_master_block`::
 (<<dynamic-cluster-setting,Dynamic>>)

+ 30 - 26
server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java

@@ -127,6 +127,7 @@ public class Coordinator extends AbstractLifecycleComponent implements ClusterSt
     private final MasterService masterService;
     private final AllocationService allocationService;
     private final JoinHelper joinHelper;
+    private final JoinValidationService joinValidationService;
     private final NodeRemovalClusterStateTaskExecutor nodeRemovalExecutor;
     private final Supplier<CoordinationState.PersistedState> persistedStateSupplier;
     private final NoMasterBlockService noMasterBlockService;
@@ -195,19 +196,22 @@ public class Coordinator extends AbstractLifecycleComponent implements ClusterSt
         this.electionStrategy = electionStrategy;
         this.joinReasonService = new JoinReasonService(transportService.getThreadPool()::relativeTimeInMillis);
         this.joinHelper = new JoinHelper(
-            settings,
             allocationService,
             masterService,
             transportService,
             this::getCurrentTerm,
-            this::getStateForMasterService,
             this::handleJoinRequest,
             this::joinLeaderInTerm,
-            this.onJoinValidators,
             rerouteService,
             nodeHealthService,
             joinReasonService
         );
+        this.joinValidationService = new JoinValidationService(
+            settings,
+            transportService,
+            this::getStateForMasterService,
+            this.onJoinValidators
+        );
         this.persistedStateSupplier = persistedStateSupplier;
         this.noMasterBlockService = new NoMasterBlockService(settings, clusterSettings);
         this.lastKnownLeader = Optional.empty();
@@ -623,7 +627,7 @@ public class Coordinator extends AbstractLifecycleComponent implements ClusterSt
         // Before letting the node join the cluster, ensure:
         // - it's a new enough version to pass the version barrier
         // - we have a healthy STATE channel to the node
-        // - if we're already master that it can make sense of a the current cluster state.
+        // - if we're already master that it can make sense of the current cluster state.
         // - we have a healthy PING channel to the node
 
         final ListenableActionFuture<Empty> validateStateListener = new ListenableActionFuture<>();
@@ -638,7 +642,7 @@ public class Coordinator extends AbstractLifecycleComponent implements ClusterSt
                     stateForJoinValidation.getNodes().getMinNodeVersion()
                 );
             }
-            sendJoinValidate(joinRequest.getSourceNode(), stateForJoinValidation, validateStateListener);
+            sendJoinValidate(joinRequest.getSourceNode(), validateStateListener);
         } else {
             sendJoinPing(joinRequest.getSourceNode(), TransportRequestOptions.Type.STATE, validateStateListener);
         }
@@ -669,27 +673,21 @@ public class Coordinator extends AbstractLifecycleComponent implements ClusterSt
         });
     }
 
-    private void sendJoinValidate(DiscoveryNode discoveryNode, ClusterState clusterState, ActionListener<Empty> listener) {
-        transportService.sendRequest(
-            discoveryNode,
-            JoinHelper.JOIN_VALIDATE_ACTION_NAME,
-            new ValidateJoinRequest(clusterState),
-            TransportRequestOptions.of(null, TransportRequestOptions.Type.STATE),
-            new ActionListenerResponseHandler<>(listener.delegateResponse((l, e) -> {
-                logger.warn(() -> new ParameterizedMessage("failed to validate incoming join request from node [{}]", discoveryNode), e);
-                listener.onFailure(
-                    new IllegalStateException(
-                        String.format(
-                            Locale.ROOT,
-                            "failure when sending a join validation request from [%s] to [%s]",
-                            getLocalNode().descriptionWithoutAttributes(),
-                            discoveryNode.descriptionWithoutAttributes()
-                        ),
-                        e
-                    )
-                );
-            }), i -> Empty.INSTANCE, Names.CLUSTER_COORDINATION)
-        );
+    private void sendJoinValidate(DiscoveryNode discoveryNode, ActionListener<Empty> listener) {
+        joinValidationService.validateJoin(discoveryNode, listener.delegateResponse((delegate, e) -> {
+            logger.warn(new ParameterizedMessage("failed to validate incoming join request from node [{}]", discoveryNode), e);
+            delegate.onFailure(
+                new IllegalStateException(
+                    String.format(
+                        Locale.ROOT,
+                        "failure when sending a join validation request from [%s] to [%s]",
+                        getLocalNode().descriptionWithoutAttributes(),
+                        discoveryNode.descriptionWithoutAttributes()
+                    ),
+                    e
+                )
+            );
+        }));
     }
 
     private void sendJoinPing(DiscoveryNode discoveryNode, TransportRequestOptions.Type channelType, ActionListener<Empty> listener) {
@@ -959,6 +957,7 @@ public class Coordinator extends AbstractLifecycleComponent implements ClusterSt
     @Override
     protected void doStop() {
         configuredHostsResolver.stop();
+        joinValidationService.stop();
     }
 
     @Override
@@ -1455,6 +1454,11 @@ public class Coordinator extends AbstractLifecycleComponent implements ClusterSt
         return onJoinValidators;
     }
 
+    // for tests
+    boolean hasIdleJoinValidationService() {
+        return joinValidationService.isIdle();
+    }
+
     public enum Mode {
         CANDIDATE,
         LEADER,

+ 0 - 36
server/src/main/java/org/elasticsearch/cluster/coordination/JoinHelper.java

@@ -13,7 +13,6 @@ import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ChannelActionListener;
-import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStateTaskConfig;
 import org.elasticsearch.cluster.NotMasterException;
 import org.elasticsearch.cluster.coordination.Coordinator.Mode;
@@ -22,13 +21,11 @@ import org.elasticsearch.cluster.routing.RerouteService;
 import org.elasticsearch.cluster.routing.allocation.AllocationService;
 import org.elasticsearch.cluster.service.MasterService;
 import org.elasticsearch.common.Priority;
-import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.core.Tuple;
-import org.elasticsearch.env.Environment;
 import org.elasticsearch.monitor.NodeHealthService;
 import org.elasticsearch.monitor.StatusInfo;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -43,7 +40,6 @@ import org.elasticsearch.transport.TransportResponseHandler;
 import org.elasticsearch.transport.TransportService;
 
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -52,7 +48,6 @@ import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
 import java.util.function.LongSupplier;
-import java.util.function.Supplier;
 
 import static org.elasticsearch.monitor.StatusInfo.Status.UNHEALTHY;
 
@@ -62,7 +57,6 @@ public class JoinHelper {
 
     public static final String START_JOIN_ACTION_NAME = "internal:cluster/coordination/start_join";
     public static final String JOIN_ACTION_NAME = "internal:cluster/coordination/join";
-    public static final String JOIN_VALIDATE_ACTION_NAME = "internal:cluster/coordination/join/validate";
     public static final String JOIN_PING_ACTION_NAME = "internal:cluster/coordination/join/ping";
 
     private final AllocationService allocationService;
@@ -79,15 +73,12 @@ public class JoinHelper {
     private final Map<DiscoveryNode, Releasable> joinConnections = new HashMap<>(); // synchronized on itself
 
     JoinHelper(
-        Settings settings,
         AllocationService allocationService,
         MasterService masterService,
         TransportService transportService,
         LongSupplier currentTermSupplier,
-        Supplier<ClusterState> currentStateSupplier,
         BiConsumer<JoinRequest, ActionListener<Void>> joinHandler,
         Function<StartJoinRequest, Join> joinLeaderInTerm,
-        Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators,
         RerouteService rerouteService,
         NodeHealthService nodeHealthService,
         JoinReasonService joinReasonService
@@ -134,33 +125,6 @@ public class JoinHelper {
             TransportRequest.Empty::new,
             (request, channel, task) -> channel.sendResponse(Empty.INSTANCE)
         );
-
-        final List<String> dataPaths = Environment.PATH_DATA_SETTING.get(settings);
-        transportService.registerRequestHandler(
-            JOIN_VALIDATE_ACTION_NAME,
-            ThreadPool.Names.CLUSTER_COORDINATION,
-            ValidateJoinRequest::new,
-            (request, channel, task) -> {
-                final ClusterState localState = currentStateSupplier.get();
-                if (localState.metadata().clusterUUIDCommitted()
-                    && localState.metadata().clusterUUID().equals(request.getState().metadata().clusterUUID()) == false) {
-                    throw new CoordinationStateRejectedException(
-                        "This node previously joined a cluster with UUID ["
-                            + localState.metadata().clusterUUID()
-                            + "] and is now trying to join a different cluster with UUID ["
-                            + request.getState().metadata().clusterUUID()
-                            + "]. This is forbidden and usually indicates an incorrect "
-                            + "discovery or cluster bootstrapping configuration. Note that the cluster UUID persists across restarts and "
-                            + "can only be changed by deleting the contents of the node's data "
-                            + (dataPaths.size() == 1 ? "path " : "paths ")
-                            + dataPaths
-                            + " which will also remove any data held by this node."
-                    );
-                }
-                joinValidators.forEach(action -> action.accept(transportService.getLocalNode(), request.getState()));
-                channel.sendResponse(Empty.INSTANCE);
-            }
-        );
     }
 
     boolean isJoinPending() {

+ 366 - 0
server/src/main/java/org/elasticsearch/cluster/coordination/JoinValidationService.java

@@ -0,0 +1,366 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.cluster.coordination;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionListenerResponseHandler;
+import org.elasticsearch.action.ActionRunnable;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.bytes.ReleasableBytesReference;
+import org.elasticsearch.common.compress.CompressorFactory;
+import org.elasticsearch.common.io.Streams;
+import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
+import org.elasticsearch.common.settings.Setting;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
+import org.elasticsearch.core.AbstractRefCounted;
+import org.elasticsearch.core.RefCounted;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.env.Environment;
+import org.elasticsearch.node.NodeClosedException;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.BytesTransportRequest;
+import org.elasticsearch.transport.TransportRequestOptions;
+import org.elasticsearch.transport.TransportResponse;
+import org.elasticsearch.transport.TransportService;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
+import java.util.function.Supplier;
+
+/**
+ * Coordinates the join validation process.
+ * <p>
+ * When a node requests to join an existing cluster, the master first sends it a copy of a recent cluster state to ensure that the new node
+ * can make sense of it (e.g. it has all the plugins it needs to even deserialize the state). Cluster states can be expensive to serialize:
+ * they are large, so we compress them, but the compression takes extra CPU. Also there may be many nodes all joining at once (e.g. after a
+ * full cluster restart or the healing of a large network partition). This component caches the serialized and compressed state that was
+ * sent to one joining node and reuses it to validate other join requests that arrive within the cache timeout, avoiding the need to
+ * allocate memory for each request and repeat all that serialization and compression work each time.
+ */
+public class JoinValidationService {
+
+    /*
+     * IMPLEMENTATION NOTES
+     *
+     * This component is based around a queue of actions which are processed in a single-threaded fashion on the CLUSTER_COORDINATION
+     * threadpool. The actions are either:
+     *
+     * - send a join validation request to a particular node, or
+     * - clear the cache
+     *
+     * The single-threadedness is arranged by tracking (a lower bound on) the size of the queue in a separate AtomicInteger, and only
+     * spawning a new processor when the tracked queue size changes from 0 to 1.
+     *
+     * The executeRefs ref counter is necessary to handle the possibility of a concurrent shutdown, ensuring that the cache is always
+     * cleared even if validateJoin is called concurrently to the shutdown.
+     */
+
+    private static final Logger logger = LogManager.getLogger(JoinValidationService.class);
+
+    public static final String JOIN_VALIDATE_ACTION_NAME = "internal:cluster/coordination/join/validate";
+
+    // the timeout for each cached value
+    public static final Setting<TimeValue> JOIN_VALIDATION_CACHE_TIMEOUT_SETTING = Setting.timeSetting(
+        "cluster.join_validation.cache_timeout",
+        TimeValue.timeValueSeconds(60),
+        TimeValue.timeValueMillis(1),
+        Setting.Property.NodeScope
+    );
+
+    private static final TransportRequestOptions REQUEST_OPTIONS = TransportRequestOptions.of(null, TransportRequestOptions.Type.STATE);
+
+    private final TimeValue cacheTimeout;
+    private final TransportService transportService;
+    private final Supplier<ClusterState> clusterStateSupplier;
+    private final AtomicInteger queueSize = new AtomicInteger();
+    private final Queue<AbstractRunnable> queue = new ConcurrentLinkedQueue<>();
+    private final Map<Version, ReleasableBytesReference> statesByVersion = new HashMap<>();
+    private final RefCounted executeRefs;
+
+    public JoinValidationService(
+        Settings settings,
+        TransportService transportService,
+        Supplier<ClusterState> clusterStateSupplier,
+        Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators
+    ) {
+        this.cacheTimeout = JOIN_VALIDATION_CACHE_TIMEOUT_SETTING.get(settings);
+        this.transportService = transportService;
+        this.clusterStateSupplier = clusterStateSupplier;
+        this.executeRefs = AbstractRefCounted.of(() -> execute(cacheClearer));
+
+        final var dataPaths = Environment.PATH_DATA_SETTING.get(settings);
+        transportService.registerRequestHandler(
+            JoinValidationService.JOIN_VALIDATE_ACTION_NAME,
+            ThreadPool.Names.CLUSTER_COORDINATION,
+            ValidateJoinRequest::new,
+            (request, channel, task) -> {
+                final var remoteState = request.getOrReadState();
+                final var localState = clusterStateSupplier.get();
+                if (localState.metadata().clusterUUIDCommitted()
+                    && localState.metadata().clusterUUID().equals(remoteState.metadata().clusterUUID()) == false) {
+                    throw new CoordinationStateRejectedException(
+                        "This node previously joined a cluster with UUID ["
+                            + localState.metadata().clusterUUID()
+                            + "] and is now trying to join a different cluster with UUID ["
+                            + remoteState.metadata().clusterUUID()
+                            + "]. This is forbidden and usually indicates an incorrect "
+                            + "discovery or cluster bootstrapping configuration. Note that the cluster UUID persists across restarts and "
+                            + "can only be changed by deleting the contents of the node's data "
+                            + (dataPaths.size() == 1 ? "path " : "paths ")
+                            + dataPaths
+                            + " which will also remove any data held by this node."
+                    );
+                }
+                joinValidators.forEach(joinValidator -> joinValidator.accept(transportService.getLocalNode(), remoteState));
+                channel.sendResponse(TransportResponse.Empty.INSTANCE);
+            }
+        );
+    }
+
+    public void validateJoin(DiscoveryNode discoveryNode, ActionListener<TransportResponse.Empty> listener) {
+        if (discoveryNode.getVersion().onOrAfter(Version.V_8_3_0)) {
+            if (executeRefs.tryIncRef()) {
+                try {
+                    execute(new JoinValidation(discoveryNode, listener));
+                } finally {
+                    executeRefs.decRef();
+                }
+            } else {
+                listener.onFailure(new NodeClosedException(transportService.getLocalNode()));
+            }
+        } else {
+            transportService.sendRequest(
+                discoveryNode,
+                JOIN_VALIDATE_ACTION_NAME,
+                new ValidateJoinRequest(clusterStateSupplier.get()),
+                REQUEST_OPTIONS,
+                new ActionListenerResponseHandler<>(listener.delegateResponse((l, e) -> {
+                    logger.warn(
+                        () -> new ParameterizedMessage("failed to validate incoming join request from node [{}]", discoveryNode),
+                        e
+                    );
+                    listener.onFailure(
+                        new IllegalStateException(
+                            String.format(
+                                Locale.ROOT,
+                                "failure when sending a join validation request from [%s] to [%s]",
+                                transportService.getLocalNode().descriptionWithoutAttributes(),
+                                discoveryNode.descriptionWithoutAttributes()
+                            ),
+                            e
+                        )
+                    );
+                }), i -> TransportResponse.Empty.INSTANCE, ThreadPool.Names.CLUSTER_COORDINATION)
+            );
+        }
+    }
+
+    public void stop() {
+        executeRefs.decRef();
+    }
+
+    boolean isIdle() {
+        // this is for single-threaded tests to assert that the service becomes idle, so it is not properly synchronized
+        return queue.isEmpty() && queueSize.get() == 0 && statesByVersion.isEmpty();
+    }
+
+    private void execute(AbstractRunnable task) {
+        assert task == cacheClearer || executeRefs.hasReferences();
+        queue.add(task);
+        if (queueSize.getAndIncrement() == 0) {
+            runProcessor();
+        }
+    }
+
+    private void runProcessor() {
+        transportService.getThreadPool().executor(ThreadPool.Names.CLUSTER_COORDINATION).execute(processor);
+    }
+
+    private final AbstractRunnable processor = new AbstractRunnable() {
+        @Override
+        protected void doRun() {
+            processNextItem();
+        }
+
+        @Override
+        public void onRejection(Exception e) {
+            assert e instanceof EsRejectedExecutionException esre && esre.isExecutorShutdown();
+            onShutdown();
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            logger.error("unexpectedly failed to process queue item", e);
+            assert false : e;
+        }
+
+        @Override
+        public String toString() {
+            return "process next task of join validation service";
+        }
+    };
+
+    private void processNextItem() {
+        if (executeRefs.hasReferences() == false) {
+            onShutdown();
+            return;
+        }
+
+        final var nextItem = queue.poll();
+        assert nextItem != null;
+        try {
+            nextItem.run();
+        } finally {
+            final var remaining = queueSize.decrementAndGet();
+            assert remaining >= 0;
+            if (remaining > 0) {
+                runProcessor();
+            }
+        }
+    }
+
+    private void onShutdown() {
+        // shutting down when enqueueing the next processor run which means there is no active processor so it's safe to clear out the
+        // cache ...
+        cacheClearer.run();
+
+        // ... and drain the queue
+        do {
+            final var nextItem = queue.poll();
+            assert nextItem != null;
+            if (nextItem != cacheClearer) {
+                nextItem.onFailure(new NodeClosedException(transportService.getLocalNode()));
+            }
+        } while (queueSize.decrementAndGet() > 0);
+    }
+
+    private final AbstractRunnable cacheClearer = new AbstractRunnable() {
+        @Override
+        public void onFailure(Exception e) {
+            logger.error("unexpectedly failed to clear cache", e);
+            assert false : e;
+        }
+
+        @Override
+        protected void doRun() {
+            // NB this never runs concurrently to JoinValidation actions, nor to itself, (see IMPLEMENTATION NOTES above) so it is safe
+            // to do these (non-atomic) things to the (unsynchronized) statesByVersion map.
+            for (final var bytes : statesByVersion.values()) {
+                bytes.decRef();
+            }
+            statesByVersion.clear();
+            logger.trace("join validation cache cleared");
+        }
+
+        @Override
+        public String toString() {
+            return "clear join validation cache";
+        }
+    };
+
+    private class JoinValidation extends ActionRunnable<TransportResponse.Empty> {
+        private final DiscoveryNode discoveryNode;
+
+        JoinValidation(DiscoveryNode discoveryNode, ActionListener<TransportResponse.Empty> listener) {
+            super(listener);
+            this.discoveryNode = discoveryNode;
+        }
+
+        @Override
+        protected void doRun() throws Exception {
+            assert discoveryNode.getVersion().onOrAfter(Version.V_8_3_0) : discoveryNode.getVersion();
+            // NB these things never run concurrently to each other, or to the cache cleaner (see IMPLEMENTATION NOTES above) so it is safe
+            // to do these (non-atomic) things to the (unsynchronized) statesByVersion map.
+            final var cachedBytes = statesByVersion.get(discoveryNode.getVersion());
+            final var bytes = Objects.requireNonNullElseGet(cachedBytes, () -> serializeClusterState(discoveryNode));
+            assert bytes.hasReferences() : "already closed";
+            bytes.incRef();
+            transportService.sendRequest(
+                discoveryNode,
+                JOIN_VALIDATE_ACTION_NAME,
+                new BytesTransportRequest(bytes, discoveryNode.getVersion()),
+                REQUEST_OPTIONS,
+                new ActionListenerResponseHandler<>(
+                    ActionListener.runAfter(listener, bytes::decRef),
+                    in -> TransportResponse.Empty.INSTANCE,
+                    ThreadPool.Names.CLUSTER_COORDINATION
+                )
+            );
+            if (cachedBytes == null) {
+                transportService.getThreadPool().schedule(new Runnable() {
+                    @Override
+                    public void run() {
+                        execute(cacheClearer);
+                    }
+
+                    @Override
+                    public String toString() {
+                        return cacheClearer + " after timeout";
+                    }
+                }, cacheTimeout, ThreadPool.Names.CLUSTER_COORDINATION);
+            }
+        }
+
+        @Override
+        public String toString() {
+            return "send cached join validation request to " + discoveryNode;
+        }
+    }
+
+    private ReleasableBytesReference serializeClusterState(DiscoveryNode discoveryNode) {
+        final var bytesStream = transportService.newNetworkBytesStream();
+        var success = false;
+        try {
+            final var clusterState = clusterStateSupplier.get();
+            final var version = discoveryNode.getVersion();
+            try (
+                var stream = new OutputStreamStreamOutput(
+                    CompressorFactory.COMPRESSOR.threadLocalOutputStream(Streams.flushOnCloseStream(bytesStream))
+                )
+            ) {
+                stream.setVersion(version);
+                clusterState.writeTo(stream);
+            } catch (IOException e) {
+                throw new ElasticsearchException("failed to serialize cluster state for publishing to node {}", e, discoveryNode);
+            }
+            final var newBytes = new ReleasableBytesReference(bytesStream.bytes(), bytesStream);
+            logger.trace(
+                "serialized join validation cluster state version [{}] for node version [{}] with size [{}]",
+                clusterState.version(),
+                version,
+                newBytes.length()
+            );
+            final var previousBytes = statesByVersion.put(version, newBytes);
+            success = true;
+            assert previousBytes == null;
+            return newBytes;
+        } finally {
+            if (success == false) {
+                assert false;
+                bytesStream.close();
+            }
+        }
+    }
+}

+ 65 - 6
server/src/main/java/org/elasticsearch/cluster/coordination/ValidateJoinRequest.java

@@ -7,32 +7,91 @@
  */
 package org.elasticsearch.cluster.coordination;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.common.CheckedSupplier;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.compress.CompressorFactory;
+import org.elasticsearch.common.io.stream.InputStreamStreamInput;
+import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.RefCounted;
 import org.elasticsearch.transport.TransportRequest;
 
 import java.io.IOException;
 
 public class ValidateJoinRequest extends TransportRequest {
-    private ClusterState state;
+    private final CheckedSupplier<ClusterState, IOException> stateSupplier;
+    private final RefCounted refCounted;
 
     public ValidateJoinRequest(StreamInput in) throws IOException {
         super(in);
-        this.state = ClusterState.readFrom(in, null);
+        if (in.getVersion().onOrAfter(Version.V_8_3_0)) {
+            // recent versions send a BytesTransportRequest containing a compressed representation of the state
+            final var bytes = in.readReleasableBytesReference();
+            final var version = in.getVersion();
+            final var namedWriteableRegistry = in.namedWriteableRegistry();
+            this.stateSupplier = () -> readCompressed(version, bytes, namedWriteableRegistry);
+            this.refCounted = bytes;
+        } else {
+            // older versions just contain the bare state
+            final var state = ClusterState.readFrom(in, null);
+            this.stateSupplier = () -> state;
+            this.refCounted = null;
+        }
+    }
+
+    private static ClusterState readCompressed(Version version, BytesReference bytes, NamedWriteableRegistry namedWriteableRegistry)
+        throws IOException {
+        try (
+            var bytesStreamInput = bytes.streamInput();
+            var in = new NamedWriteableAwareStreamInput(
+                new InputStreamStreamInput(CompressorFactory.COMPRESSOR.threadLocalInputStream(bytesStreamInput)),
+                namedWriteableRegistry
+            )
+        ) {
+            in.setVersion(version);
+            return ClusterState.readFrom(in, null);
+        }
     }
 
     public ValidateJoinRequest(ClusterState state) {
-        this.state = state;
+        this.stateSupplier = () -> state;
+        this.refCounted = null;
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        assert out.getVersion().before(Version.V_8_3_0);
         super.writeTo(out);
-        this.state.writeTo(out);
+        stateSupplier.get().writeTo(out);
+    }
+
+    public ClusterState getOrReadState() throws IOException {
+        return stateSupplier.get();
+    }
+
+    @Override
+    public void incRef() {
+        if (refCounted != null) {
+            refCounted.incRef();
+        }
     }
 
-    public ClusterState getState() {
-        return state;
+    @Override
+    public boolean tryIncRef() {
+        return refCounted == null || refCounted.tryIncRef();
+    }
+
+    @Override
+    public boolean decRef() {
+        return refCounted != null && refCounted.decRef();
+    }
+
+    @Override
+    public boolean hasReferences() {
+        return refCounted == null || refCounted.hasReferences();
     }
 }

+ 2 - 0
server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java

@@ -27,6 +27,7 @@ import org.elasticsearch.cluster.coordination.ClusterFormationFailureHelper;
 import org.elasticsearch.cluster.coordination.Coordinator;
 import org.elasticsearch.cluster.coordination.ElectionSchedulerFactory;
 import org.elasticsearch.cluster.coordination.FollowersChecker;
+import org.elasticsearch.cluster.coordination.JoinValidationService;
 import org.elasticsearch.cluster.coordination.LagDetector;
 import org.elasticsearch.cluster.coordination.LeaderChecker;
 import org.elasticsearch.cluster.coordination.NoMasterBlockService;
@@ -490,6 +491,7 @@ public final class ClusterSettings extends AbstractScopedSettings {
         ElectionSchedulerFactory.ELECTION_DURATION_SETTING,
         Coordinator.PUBLISH_TIMEOUT_SETTING,
         Coordinator.PUBLISH_INFO_TIMEOUT_SETTING,
+        JoinValidationService.JOIN_VALIDATION_CACHE_TIMEOUT_SETTING,
         FollowersChecker.FOLLOWER_CHECK_TIMEOUT_SETTING,
         FollowersChecker.FOLLOWER_CHECK_INTERVAL_SETTING,
         FollowersChecker.FOLLOWER_CHECK_RETRY_COUNT_SETTING,

+ 0 - 179
server/src/test/java/org/elasticsearch/cluster/coordination/JoinHelperTests.java

@@ -10,46 +10,28 @@ package org.elasticsearch.cluster.coordination;
 import org.apache.logging.log4j.Level;
 import org.elasticsearch.Build;
 import org.elasticsearch.Version;
-import org.elasticsearch.action.ActionListenerResponseHandler;
-import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.ClusterName;
-import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.NotMasterException;
-import org.elasticsearch.cluster.SimpleDiffable;
-import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
-import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue;
-import org.elasticsearch.core.Releasable;
-import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.env.Environment;
 import org.elasticsearch.monitor.StatusInfo;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.transport.CapturingTransport;
 import org.elasticsearch.test.transport.CapturingTransport.CapturedRequest;
-import org.elasticsearch.test.transport.MockTransport;
-import org.elasticsearch.test.transport.MockTransportService;
-import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.transport.AbstractSimpleTransportTestCase;
 import org.elasticsearch.transport.ClusterConnectionManager;
 import org.elasticsearch.transport.RemoteTransportException;
 import org.elasticsearch.transport.TransportException;
 import org.elasticsearch.transport.TransportRequest;
 import org.elasticsearch.transport.TransportResponse;
 import org.elasticsearch.transport.TransportService;
-import org.elasticsearch.xcontent.XContentBuilder;
 
-import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Optional;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -57,12 +39,8 @@ import java.util.stream.Stream;
 import static org.elasticsearch.cluster.coordination.JoinHelper.PENDING_JOIN_WAITING_RESPONSE;
 import static org.elasticsearch.monitor.StatusInfo.Status.HEALTHY;
 import static org.elasticsearch.monitor.StatusInfo.Status.UNHEALTHY;
-import static org.elasticsearch.transport.AbstractSimpleTransportTestCase.IGNORE_DESERIALIZATION_ERRORS_SETTING;
 import static org.elasticsearch.transport.TransportService.HANDSHAKE_ACTION_NAME;
-import static org.hamcrest.Matchers.allOf;
-import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.core.Is.is;
 
 public class JoinHelperTests extends ESTestCase {
@@ -83,15 +61,12 @@ public class JoinHelperTests extends ESTestCase {
             new ClusterConnectionManager(Settings.EMPTY, capturingTransport, threadPool.getThreadContext())
         );
         JoinHelper joinHelper = new JoinHelper(
-            Settings.EMPTY,
             null,
             null,
             transportService,
             () -> 0L,
-            () -> null,
             (joinRequest, joinCallback) -> { throw new AssertionError(); },
             startJoinRequest -> { throw new AssertionError(); },
-            Collections.emptyList(),
             (s, p, r) -> {},
             () -> new StatusInfo(HEALTHY, "info"),
             new JoinReasonService(() -> 0L)
@@ -226,70 +201,6 @@ public class JoinHelperTests extends ESTestCase {
         );
     }
 
-    public void testJoinValidationRejectsMismatchedClusterUUID() {
-        DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue();
-        MockTransport mockTransport = new MockTransport();
-        DiscoveryNode localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT);
-
-        final ClusterState localClusterState = ClusterState.builder(ClusterName.DEFAULT)
-            .metadata(Metadata.builder().generateClusterUuidIfNeeded().clusterUUIDCommitted(true))
-            .build();
-
-        TransportService transportService = mockTransport.createTransportService(
-            Settings.EMPTY,
-            deterministicTaskQueue.getThreadPool(),
-            TransportService.NOOP_TRANSPORT_INTERCEPTOR,
-            x -> localNode,
-            null,
-            Collections.emptySet()
-        );
-        final String dataPath = "/my/data/path";
-        new JoinHelper(
-            Settings.builder().put(Environment.PATH_DATA_SETTING.getKey(), dataPath).build(),
-            null,
-            null,
-            transportService,
-            () -> 0L,
-            () -> localClusterState,
-            (joinRequest, joinCallback) -> { throw new AssertionError(); },
-            startJoinRequest -> { throw new AssertionError(); },
-            Collections.emptyList(),
-            (s, p, r) -> {},
-            null,
-            new JoinReasonService(() -> 0L)
-        ); // registers request handler
-        transportService.start();
-        transportService.acceptIncomingRequests();
-
-        final ClusterState otherClusterState = ClusterState.builder(ClusterName.DEFAULT)
-            .metadata(Metadata.builder().generateClusterUuidIfNeeded())
-            .build();
-
-        final PlainActionFuture<TransportResponse.Empty> future = new PlainActionFuture<>();
-        transportService.sendRequest(
-            localNode,
-            JoinHelper.JOIN_VALIDATE_ACTION_NAME,
-            new ValidateJoinRequest(otherClusterState),
-            new ActionListenerResponseHandler<>(future, in -> TransportResponse.Empty.INSTANCE)
-        );
-        deterministicTaskQueue.runAllTasks();
-
-        final CoordinationStateRejectedException coordinationStateRejectedException = expectThrows(
-            CoordinationStateRejectedException.class,
-            future::actionGet
-        );
-        assertThat(
-            coordinationStateRejectedException.getMessage(),
-            allOf(
-                containsString("This node previously joined a cluster with UUID"),
-                containsString("and is now trying to join a different cluster"),
-                containsString(localClusterState.metadata().clusterUUID()),
-                containsString(otherClusterState.metadata().clusterUUID()),
-                containsString("data path [" + dataPath + "]")
-            )
-        );
-    }
-
     public void testJoinFailureOnUnhealthyNodes() {
         DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue();
         CapturingTransport capturingTransport = new HandshakingCapturingTransport();
@@ -304,15 +215,12 @@ public class JoinHelperTests extends ESTestCase {
         );
         AtomicReference<StatusInfo> nodeHealthServiceStatus = new AtomicReference<>(new StatusInfo(UNHEALTHY, "unhealthy-info"));
         JoinHelper joinHelper = new JoinHelper(
-            Settings.EMPTY,
             null,
             null,
             transportService,
             () -> 0L,
-            () -> null,
             (joinRequest, joinCallback) -> { throw new AssertionError(); },
             startJoinRequest -> { throw new AssertionError(); },
-            Collections.emptyList(),
             (s, p, r) -> {},
             nodeHealthServiceStatus::get,
             new JoinReasonService(() -> 0L)
@@ -356,72 +264,6 @@ public class JoinHelperTests extends ESTestCase {
         assertEquals(node1, capturedRequest1a.node());
     }
 
-    public void testJoinValidationFailsOnUnreadableClusterState() {
-        final List<Releasable> releasables = new ArrayList<>(3);
-        try {
-            final ThreadPool threadPool = new TestThreadPool("test");
-            releasables.add(() -> ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS));
-
-            final TransportService remoteTransportService = MockTransportService.createNewService(
-                Settings.builder().put(IGNORE_DESERIALIZATION_ERRORS_SETTING.getKey(), true).build(),
-                Version.CURRENT,
-                threadPool
-            );
-            releasables.add(remoteTransportService);
-
-            new JoinHelper(
-                Settings.EMPTY,
-                null,
-                null,
-                remoteTransportService,
-                () -> 0L,
-                () -> null,
-                (joinRequest, joinCallback) -> { throw new AssertionError(); },
-                startJoinRequest -> { throw new AssertionError(); },
-                Collections.emptyList(),
-                (s, p, r) -> {},
-                () -> { throw new AssertionError(); },
-                new JoinReasonService(() -> 0L)
-            );
-
-            remoteTransportService.start();
-            remoteTransportService.acceptIncomingRequests();
-
-            final TransportService localTransportService = MockTransportService.createNewService(
-                Settings.EMPTY,
-                Version.CURRENT,
-                threadPool
-            );
-            releasables.add(localTransportService);
-
-            localTransportService.start();
-            localTransportService.acceptIncomingRequests();
-
-            AbstractSimpleTransportTestCase.connectToNode(localTransportService, remoteTransportService.getLocalNode());
-
-            final PlainActionFuture<TransportResponse.Empty> future = new PlainActionFuture<>();
-            localTransportService.sendRequest(
-                remoteTransportService.getLocalNode(),
-                JoinHelper.JOIN_VALIDATE_ACTION_NAME,
-                new ValidateJoinRequest(ClusterState.builder(ClusterName.DEFAULT).putCustom("test", new BadCustom()).build()),
-                new ActionListenerResponseHandler<>(future, in -> TransportResponse.Empty.INSTANCE)
-            );
-
-            final RemoteTransportException exception = expectThrows(
-                ExecutionException.class,
-                RemoteTransportException.class,
-                () -> future.get(10, TimeUnit.SECONDS)
-            );
-            assertThat(exception, instanceOf(RemoteTransportException.class));
-            assertThat(exception.getCause(), instanceOf(IllegalArgumentException.class));
-            assertThat(exception.getCause().getMessage(), containsString("Unknown NamedWriteable"));
-
-        } finally {
-            Collections.reverse(releasables);
-            Releasables.close(releasables);
-        }
-    }
-
     private static class HandshakingCapturingTransport extends CapturingTransport {
 
         @Override
@@ -436,25 +278,4 @@ public class JoinHelperTests extends ESTestCase {
             }
         }
     }
-
-    private static class BadCustom implements SimpleDiffable<ClusterState.Custom>, ClusterState.Custom {
-
-        @Override
-        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            return builder;
-        }
-
-        @Override
-        public String getWriteableName() {
-            return "deliberately-unknown";
-        }
-
-        @Override
-        public Version getMinimalSupportedVersion() {
-            return Version.CURRENT;
-        }
-
-        @Override
-        public void writeTo(StreamOutput out) throws IOException {}
-    }
 }

+ 418 - 0
server/src/test/java/org/elasticsearch/cluster/coordination/JoinValidationServiceTests.java

@@ -0,0 +1,418 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.cluster.coordination;
+
+import org.elasticsearch.Build;
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionListenerResponseHandler;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.cluster.ClusterName;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.SimpleDiffable;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.component.Lifecycle;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.env.Environment;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.transport.MockTransport;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.CloseableConnection;
+import org.elasticsearch.transport.RemoteTransportException;
+import org.elasticsearch.transport.TestTransportChannel;
+import org.elasticsearch.transport.TransportException;
+import org.elasticsearch.transport.TransportRequest;
+import org.elasticsearch.transport.TransportRequestOptions;
+import org.elasticsearch.transport.TransportResponse;
+import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.containsString;
+
+public class JoinValidationServiceTests extends ESTestCase {
+
+    public void testConcurrentBehaviour() throws Exception {
+        final var releasables = new ArrayList<Releasable>();
+        try {
+            final var settingsBuilder = Settings.builder();
+            settingsBuilder.put(
+                JoinValidationService.JOIN_VALIDATION_CACHE_TIMEOUT_SETTING.getKey(),
+                TimeValue.timeValueMillis(between(1, 1000))
+            );
+            if (randomBoolean()) {
+                settingsBuilder.put("thread_pool.cluster_coordination.size", between(1, 5));
+            }
+            final var settings = settingsBuilder.build();
+
+            final var threadPool = new TestThreadPool("test", settings);
+            releasables.add(() -> ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS));
+
+            final var sendCountdown = new CountDownLatch(between(0, 10));
+
+            final var transport = new MockTransport() {
+                @Override
+                public Connection createConnection(DiscoveryNode node) {
+                    return new CloseableConnection() {
+                        @Override
+                        public DiscoveryNode getNode() {
+                            return node;
+                        }
+
+                        @Override
+                        public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options)
+                            throws TransportException {
+                            final var executor = threadPool.executor(
+                                randomFrom(ThreadPool.Names.SAME, ThreadPool.Names.GENERIC, ThreadPool.Names.CLUSTER_COORDINATION)
+                            );
+                            executor.execute(new AbstractRunnable() {
+                                @Override
+                                public void onFailure(Exception e) {
+                                    assert false : e;
+                                }
+
+                                @Override
+                                public void onRejection(Exception e) {
+                                    handleError(requestId, new TransportException(e));
+                                }
+
+                                @Override
+                                public void doRun() {
+                                    handleResponse(requestId, switch (action) {
+                                        case JoinValidationService.JOIN_VALIDATE_ACTION_NAME -> TransportResponse.Empty.INSTANCE;
+                                        case TransportService.HANDSHAKE_ACTION_NAME -> new TransportService.HandshakeResponse(
+                                            Version.CURRENT,
+                                            Build.CURRENT.hash(),
+                                            node,
+                                            ClusterName.DEFAULT
+                                        );
+                                        default -> throw new AssertionError("unexpected action: " + action);
+                                    });
+                                    sendCountdown.countDown();
+                                }
+                            });
+                        }
+                    };
+                }
+            };
+
+            final var localNode = new DiscoveryNode("local", buildNewFakeTransportAddress(), Version.CURRENT);
+
+            final var transportService = new TransportService(
+                settings,
+                transport,
+                threadPool,
+                TransportService.NOOP_TRANSPORT_INTERCEPTOR,
+                ignored -> localNode,
+                null,
+                Set.of()
+            );
+            releasables.add(transportService);
+
+            final var clusterState = ClusterState.EMPTY_STATE;
+
+            final var joinValidationService = new JoinValidationService(settings, transportService, () -> clusterState, List.of());
+
+            transportService.start();
+            releasables.add(() -> {
+                if (transportService.lifecycleState() == Lifecycle.State.STARTED) {
+                    transportService.stop();
+                }
+            });
+
+            transportService.acceptIncomingRequests();
+
+            final var otherNodes = new DiscoveryNode[between(1, 10)];
+            for (int i = 0; i < otherNodes.length; i++) {
+                otherNodes[i] = new DiscoveryNode("other-" + i, buildNewFakeTransportAddress(), Version.CURRENT);
+                final var connectionListener = new PlainActionFuture<Releasable>();
+                transportService.connectToNode(otherNodes[i], connectionListener);
+                releasables.add(connectionListener.get(10, TimeUnit.SECONDS));
+            }
+
+            final var threads = new Thread[between(1, 3)];
+            final var startBarrier = new CyclicBarrier(threads.length + 1);
+            final var permitCount = 100; // prevent too many concurrent requests or else cleanup can take ages
+            final var validationPermits = new Semaphore(permitCount);
+            final var expectFailures = new AtomicBoolean(false);
+            final var keepGoing = new AtomicBoolean(true);
+            for (int i = 0; i < threads.length; i++) {
+                final var seed = randomLong();
+                threads[i] = new Thread(() -> {
+                    final var random = new Random(seed);
+                    try {
+                        startBarrier.await(10, TimeUnit.SECONDS);
+                    } catch (Exception e) {
+                        throw new AssertionError(e);
+                    }
+
+                    while (keepGoing.get()) {
+                        Thread.yield();
+                        if (validationPermits.tryAcquire()) {
+                            joinValidationService.validateJoin(
+                                randomFrom(random, otherNodes),
+                                ActionListener.notifyOnce(new ActionListener<>() {
+                                    @Override
+                                    public void onResponse(TransportResponse.Empty empty) {
+                                        validationPermits.release();
+                                    }
+
+                                    @Override
+                                    public void onFailure(Exception e) {
+                                        validationPermits.release();
+                                        assert expectFailures.get() : e;
+                                    }
+                                })
+                            );
+                        }
+                    }
+                }, "join-validating-thread-" + i);
+                threads[i].start();
+            }
+
+            startBarrier.await(10, TimeUnit.SECONDS);
+            assertTrue(sendCountdown.await(10, TimeUnit.SECONDS));
+
+            expectFailures.set(true);
+            switch (between(1, 3)) {
+                case 1 -> joinValidationService.stop();
+                case 2 -> {
+                    joinValidationService.stop();
+                    transportService.close();
+                    ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS);
+                }
+                case 3 -> {
+                    transportService.close();
+                    keepGoing.set(false); // else the test threads keep adding to the validation service queue so the processor never stops
+                    ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS);
+                    joinValidationService.stop();
+                }
+            }
+            keepGoing.set(false);
+            for (final var thread : threads) {
+                thread.join();
+            }
+            assertTrue(validationPermits.tryAcquire(permitCount, 10, TimeUnit.SECONDS));
+            assertBusy(() -> assertTrue(joinValidationService.isIdle()));
+        } finally {
+            Collections.reverse(releasables);
+            Releasables.close(releasables);
+        }
+    }
+
+    public void testJoinValidationRejectsUnreadableClusterState() {
+
+        class BadCustom implements SimpleDiffable<ClusterState.Custom>, ClusterState.Custom {
+
+            @Override
+            public XContentBuilder toXContent(XContentBuilder builder, Params params) {
+                return builder;
+            }
+
+            @Override
+            public String getWriteableName() {
+                return "deliberately-unknown";
+            }
+
+            @Override
+            public Version getMinimalSupportedVersion() {
+                return Version.CURRENT;
+            }
+
+            @Override
+            public void writeTo(StreamOutput out) {}
+        }
+
+        final var deterministicTaskQueue = new DeterministicTaskQueue();
+        final var clusterState = ClusterState.builder(ClusterName.DEFAULT).putCustom("test", new BadCustom()).build();
+
+        final var joiningNode = new DiscoveryNode("joining", buildNewFakeTransportAddress(), Version.CURRENT);
+        final var joiningNodeTransport = new MockTransport();
+        final var joiningNodeTransportService = joiningNodeTransport.createTransportService(
+            Settings.EMPTY,
+            deterministicTaskQueue.getThreadPool(),
+            TransportService.NOOP_TRANSPORT_INTERCEPTOR,
+            x -> joiningNode,
+            null,
+            Collections.emptySet()
+        );
+        new JoinValidationService(Settings.EMPTY, joiningNodeTransportService, () -> clusterState, List.of()); // registers request handler
+        joiningNodeTransportService.start();
+        joiningNodeTransportService.acceptIncomingRequests();
+
+        final var masterNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT);
+        final var masterTransport = new MockTransport() {
+            @Override
+            protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode node) {
+                assertSame(node, joiningNode);
+                assertEquals(JoinValidationService.JOIN_VALIDATE_ACTION_NAME, action);
+
+                final var listener = new ActionListener<TransportResponse>() {
+                    @Override
+                    public void onResponse(TransportResponse transportResponse) {
+                        fail("should not succeed");
+                    }
+
+                    @Override
+                    public void onFailure(Exception e) {
+                        handleError(requestId, new RemoteTransportException(node.getName(), node.getAddress(), action, e));
+                    }
+                };
+
+                try (var out = new BytesStreamOutput()) {
+                    request.writeTo(out);
+                    out.flush();
+                    final var handler = joiningNodeTransport.getRequestHandlers().getHandler(action);
+                    handler.processMessageReceived(
+                        handler.newRequest(new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writeableRegistry())),
+                        new TestTransportChannel(listener)
+                    );
+                } catch (Exception e) {
+                    listener.onFailure(e);
+                }
+            }
+        };
+        final var masterTransportService = masterTransport.createTransportService(
+            Settings.EMPTY,
+            deterministicTaskQueue.getThreadPool(),
+            TransportService.NOOP_TRANSPORT_INTERCEPTOR,
+            x -> masterNode,
+            null,
+            Collections.emptySet()
+        );
+        final var joinValidationService = new JoinValidationService(Settings.EMPTY, masterTransportService, () -> clusterState, List.of());
+        masterTransportService.start();
+        masterTransportService.acceptIncomingRequests();
+
+        try {
+            final var future = new PlainActionFuture<TransportResponse.Empty>();
+            joinValidationService.validateJoin(joiningNode, future);
+            assertFalse(future.isDone());
+            deterministicTaskQueue.runAllTasks();
+            assertTrue(future.isDone());
+            assertThat(
+                expectThrows(IllegalArgumentException.class, future::actionGet).getMessage(),
+                allOf(containsString("Unknown NamedWriteable"), containsString("deliberately-unknown"))
+            );
+        } finally {
+            joinValidationService.stop();
+            masterTransportService.close();
+            joiningNodeTransportService.close();
+        }
+    }
+
+    public void testJoinValidationRejectsMismatchedClusterUUID() {
+        final var deterministicTaskQueue = new DeterministicTaskQueue();
+        final var mockTransport = new MockTransport();
+        final var localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT);
+
+        final var localClusterState = ClusterState.builder(ClusterName.DEFAULT)
+            .metadata(Metadata.builder().generateClusterUuidIfNeeded().clusterUUIDCommitted(true))
+            .build();
+
+        final var transportService = mockTransport.createTransportService(
+            Settings.EMPTY,
+            deterministicTaskQueue.getThreadPool(),
+            TransportService.NOOP_TRANSPORT_INTERCEPTOR,
+            ignored -> localNode,
+            null,
+            Set.of()
+        );
+
+        final var dataPath = "/my/data/path";
+        final var settings = Settings.builder().put(Environment.PATH_DATA_SETTING.getKey(), dataPath).build();
+        new JoinValidationService(settings, transportService, () -> localClusterState, List.of()); // registers request handler
+        transportService.start();
+        transportService.acceptIncomingRequests();
+
+        final var otherClusterState = ClusterState.builder(ClusterName.DEFAULT)
+            .metadata(Metadata.builder().generateClusterUuidIfNeeded())
+            .build();
+
+        final var future = new PlainActionFuture<TransportResponse.Empty>();
+        transportService.sendRequest(
+            localNode,
+            JoinValidationService.JOIN_VALIDATE_ACTION_NAME,
+            new ValidateJoinRequest(otherClusterState),
+            new ActionListenerResponseHandler<>(future, in -> TransportResponse.Empty.INSTANCE)
+        );
+        deterministicTaskQueue.runAllTasks();
+
+        assertThat(
+            expectThrows(CoordinationStateRejectedException.class, future::actionGet).getMessage(),
+            allOf(
+                containsString("This node previously joined a cluster with UUID"),
+                containsString("and is now trying to join a different cluster"),
+                containsString(localClusterState.metadata().clusterUUID()),
+                containsString(otherClusterState.metadata().clusterUUID()),
+                containsString("data path [" + dataPath + "]")
+            )
+        );
+    }
+
+    public void testJoinValidationRunsJoinValidators() {
+        final var deterministicTaskQueue = new DeterministicTaskQueue();
+        final var mockTransport = new MockTransport();
+        final var localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT);
+        final var localClusterState = ClusterState.builder(ClusterName.DEFAULT).build();
+
+        final var transportService = mockTransport.createTransportService(
+            Settings.EMPTY,
+            deterministicTaskQueue.getThreadPool(),
+            TransportService.NOOP_TRANSPORT_INTERCEPTOR,
+            ignored -> localNode,
+            null,
+            Set.of()
+        );
+
+        final var stateForValidation = ClusterState.builder(ClusterName.DEFAULT).build();
+        new JoinValidationService(Settings.EMPTY, transportService, () -> localClusterState, List.of((node, state) -> {
+            assertSame(node, localNode);
+            assertSame(state, stateForValidation);
+            throw new IllegalStateException("simulated validation failure");
+        })); // registers request handler
+        transportService.start();
+        transportService.acceptIncomingRequests();
+
+        final var future = new PlainActionFuture<TransportResponse.Empty>();
+        transportService.sendRequest(
+            localNode,
+            JoinValidationService.JOIN_VALIDATE_ACTION_NAME,
+            new ValidateJoinRequest(stateForValidation),
+            new ActionListenerResponseHandler<>(future, in -> TransportResponse.Empty.INSTANCE)
+        );
+        deterministicTaskQueue.runAllTasks();
+
+        assertThat(
+            expectThrows(IllegalStateException.class, future::actionGet).getMessage(),
+            allOf(containsString("simulated validation failure"))
+        );
+    }
+}

+ 13 - 6
server/src/test/java/org/elasticsearch/cluster/coordination/NodeJoinTests.java

@@ -93,8 +93,14 @@ public class NodeJoinTests extends ESTestCase {
 
     @After
     public void tearDown() throws Exception {
-        super.tearDown();
+        masterService.stop();
+        coordinator.stop();
+        if (deterministicTaskQueue != null) {
+            deterministicTaskQueue.runAllRunnableTasks();
+        }
         masterService.close();
+        coordinator.close();
+        super.tearDown();
     }
 
     private static ClusterState initialState(DiscoveryNode localNode, long term, long version, VotingConfiguration config) {
@@ -182,11 +188,12 @@ public class NodeJoinTests extends ESTestCase {
                             initialState.getClusterName()
                         )
                     );
-                } else if (action.equals(JoinHelper.JOIN_VALIDATE_ACTION_NAME) || action.equals(JoinHelper.JOIN_PING_ACTION_NAME)) {
-                    handleResponse(requestId, new TransportResponse.Empty());
-                } else {
-                    super.onSendRequest(requestId, action, request, destination);
-                }
+                } else if (action.equals(JoinValidationService.JOIN_VALIDATE_ACTION_NAME)
+                    || action.equals(JoinHelper.JOIN_PING_ACTION_NAME)) {
+                        handleResponse(requestId, new TransportResponse.Empty());
+                    } else {
+                        super.onSendRequest(requestId, action, request, destination);
+                    }
             }
         };
         final ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);

+ 18 - 4
test/framework/src/main/java/org/elasticsearch/cluster/coordination/AbstractCoordinatorTestCase.java

@@ -245,7 +245,9 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
                 FOLLOWER_CHECK_RETRY_COUNT_SETTING
             )
             // then wait for the new leader to commit a state without the old leader
-            + DEFAULT_CLUSTER_STATE_UPDATE_DELAY;
+            + DEFAULT_CLUSTER_STATE_UPDATE_DELAY
+            // then wait for the join validation service to become idle
+            + defaultMillis(JoinValidationService.JOIN_VALIDATION_CACHE_TIMEOUT_SETTING);
 
     public class Cluster implements Releasable {
 
@@ -535,10 +537,14 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
         }
 
         public void stabilise() {
-            stabilise(DEFAULT_STABILISATION_TIME);
+            stabilise(DEFAULT_STABILISATION_TIME, true);
+        }
+
+        public void stabilise(long stabilisationDurationMillis) {
+            stabilise(stabilisationDurationMillis, false);
         }
 
-        void stabilise(long stabilisationDurationMillis) {
+        private void stabilise(long stabilisationDurationMillis, boolean expectIdleJoinValidationService) {
             assertThat(
                 "stabilisation requires default delay variability (and proper cleanup of raised variability)",
                 deterministicTaskQueue.getExecutionDelayVariabilityMillis(),
@@ -657,6 +663,13 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
                         }
                     }
                 }
+
+                if (expectIdleJoinValidationService) {
+                    // Tests run stabilise(long stabilisationDurationMillis) to assert timely recovery from a disruption. There's no need
+                    // to wait for the JoinValidationService cache to be cleared in these cases, we have enough checks that this eventually
+                    // happens anyway.
+                    assertTrue(nodeId + " has an idle join validation service", clusterNode.coordinator.hasIdleJoinValidationService());
+                }
             }
 
             final Set<String> connectedNodeIds = clusterNodes.stream()
@@ -1138,7 +1151,8 @@ public class AbstractCoordinatorTestCase extends ESTestCase {
                                 chanType,
                                 equalTo(TransportRequestOptions.Type.PING)
                             );
-                            case JoinHelper.JOIN_VALIDATE_ACTION_NAME, PublicationTransportHandler.PUBLISH_STATE_ACTION_NAME,
+                            case JoinValidationService.JOIN_VALIDATE_ACTION_NAME,
+                                 PublicationTransportHandler.PUBLISH_STATE_ACTION_NAME,
                                  Coordinator.COMMIT_STATE_ACTION_NAME -> assertThat(
                                 action,
                                 chanType,