Jelajahi Sumber

Asynchronously connect to remote clusters (#44825)

Refactors RemoteClusterConnection so that it no longer blocks while connecting to remote clusters.

Relates to #40150
Yannick Welsch 6 tahun lalu
induk
melakukan
ae486e4911

+ 184 - 209
server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java

@@ -25,11 +25,11 @@ import org.apache.lucene.store.AlreadyClosedException;
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.StepListener;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
 import org.elasticsearch.action.support.ContextPreservingActionListener;
-import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
@@ -38,7 +38,6 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.unit.TimeValue;
-import org.elasticsearch.common.util.CancellableThreads;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.internal.io.IOUtils;
@@ -48,17 +47,14 @@ import java.io.Closeable;
 import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
-import java.util.concurrent.ArrayBlockingQueue;
-import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.RejectedExecutionException;
-import java.util.concurrent.Semaphore;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Predicate;
 import java.util.function.Supplier;
@@ -138,7 +134,7 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
         if (proxyAddress == null || proxyAddress.isEmpty()) {
             return node;
         } else {
-            // resovle proxy address lazy here
+            // resolve proxy address lazy here
             InetSocketAddress proxyInetAddress = RemoteClusterAware.parseSeedAddress(proxyAddress);
             return new DiscoveryNode(node.getName(), node.getId(), node.getEphemeralId(), node.getHostName(), node
                 .getHostAddress(), new TransportAddress(proxyInetAddress), node.getAttributes(), node.getRoles(), node.getVersion());
@@ -175,7 +171,9 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
     public void onNodeDisconnected(DiscoveryNode node) {
         if (connectionManager.size() < maxNumRemoteConnections) {
             // try to reconnect and fill up the slot of the disconnected node
-            connectHandler.forceConnect();
+            connectHandler.connect(ActionListener.wrap(
+                ignore -> logger.trace("successfully connected after disconnect of {}", node),
+                e -> logger.trace(() -> new ParameterizedMessage("failed to connect after disconnect of {}", node), e)));
         }
     }
 
@@ -357,201 +355,178 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
      * we will just reject the connect trigger which will lead to failing searches.
      */
     private class ConnectHandler implements Closeable {
-        private final Semaphore running = new Semaphore(1);
+        private static final int MAX_LISTENERS = 100;
         private final AtomicBoolean closed = new AtomicBoolean(false);
-        private final BlockingQueue<ActionListener<Void>> queue = new ArrayBlockingQueue<>(100);
-        private final CancellableThreads cancellableThreads = new CancellableThreads();
-
-        /**
-         * Triggers a connect round iff there are pending requests queued up and if there is no
-         * connect round currently running.
-         */
-        void maybeConnect() {
-            connect(null);
-        }
+        private final Object mutex = new Object();
+        private List<ActionListener<Void>> listeners = new ArrayList<>();
 
         /**
          * Triggers a connect round unless there is one running already. If there is a connect round running, the listener will either
          * be queued or rejected and failed.
          */
         void connect(ActionListener<Void> connectListener) {
-            connect(connectListener, false);
-        }
-
-        /**
-         * Triggers a connect round unless there is one already running. In contrast to {@link #maybeConnect()} will this method also
-         * trigger a connect round if there is no listener queued up.
-         */
-        void forceConnect() {
-            connect(null, true);
-        }
-
-        private void connect(ActionListener<Void> connectListener, boolean forceRun) {
-            final boolean runConnect;
-            final Collection<ActionListener<Void>> toNotify;
-            final ActionListener<Void> listener = connectListener == null ? null :
+            boolean runConnect = false;
+            final ActionListener<Void> listener =
                 ContextPreservingActionListener.wrapPreservingContext(connectListener, threadPool.getThreadContext());
-            synchronized (queue) {
-                if (listener != null && queue.offer(listener) == false) {
-                    listener.onFailure(new RejectedExecutionException("connect queue is full"));
-                    return;
-                }
-                if (forceRun == false && queue.isEmpty()) {
-                    return;
-                }
-                runConnect = running.tryAcquire();
-                if (runConnect) {
-                    toNotify = new ArrayList<>();
-                    queue.drainTo(toNotify);
-                    if (closed.get()) {
-                        running.release();
-                        ActionListener.onFailure(toNotify, new AlreadyClosedException("connect handler is already closed"));
+            synchronized (mutex) {
+                if (closed.get()) {
+                    assert listeners.isEmpty();
+                } else {
+                    if (listeners.size() >= MAX_LISTENERS) {
+                        assert listeners.size() == MAX_LISTENERS;
+                        listener.onFailure(new RejectedExecutionException("connect queue is full"));
                         return;
+                    } else {
+                        listeners.add(listener);
                     }
-                } else {
-                    toNotify = Collections.emptyList();
+                    runConnect = listeners.size() == 1;
                 }
             }
-            if (runConnect) {
-                forkConnect(toNotify);
+            if (closed.get()) {
+                connectListener.onFailure(new AlreadyClosedException("connect handler is already closed"));
+                return;
             }
-        }
-
-        private void forkConnect(final Collection<ActionListener<Void>> toNotify) {
-            ExecutorService executor = threadPool.executor(ThreadPool.Names.MANAGEMENT);
-            executor.submit(new AbstractRunnable() {
-                @Override
-                public void onFailure(Exception e) {
-                    synchronized (queue) {
-                        running.release();
-                    }
-                    try {
-                        ActionListener.onFailure(toNotify, e);
-                    } finally {
-                        maybeConnect();
+            if (runConnect) {
+                ExecutorService executor = threadPool.executor(ThreadPool.Names.MANAGEMENT);
+                executor.submit(new AbstractRunnable() {
+                    @Override
+                    public void onFailure(Exception e) {
+                        ActionListener.onFailure(getAndClearListeners(), e);
                     }
-                }
 
-                @Override
-                protected void doRun() {
-                    ActionListener<Void> listener = ActionListener.wrap((x) -> {
-                        synchronized (queue) {
-                            running.release();
-                        }
-                        try {
-                            ActionListener.onResponse(toNotify, x);
-                        } finally {
-                            maybeConnect();
-                        }
+                    @Override
+                    protected void doRun() {
+                        collectRemoteNodes(seedNodes.stream().map(Tuple::v2).iterator(),
+                            new ActionListener<>() {
+                                @Override
+                                public void onResponse(Void aVoid) {
+                                    ActionListener.onResponse(getAndClearListeners(), aVoid);
+                                }
 
-                    }, (e) -> {
-                        synchronized (queue) {
-                            running.release();
-                        }
-                        try {
-                            ActionListener.onFailure(toNotify, e);
-                        } finally {
-                            maybeConnect();
-                        }
-                    });
-                    collectRemoteNodes(seedNodes.stream().map(Tuple::v2).iterator(), transportService, connectionManager, listener);
+                                @Override
+                                public void onFailure(Exception e) {
+                                    ActionListener.onFailure(getAndClearListeners(), e);
+                                }
+                            });
+                    }
+                });
+            }
+        }
+
+        private List<ActionListener<Void>> getAndClearListeners() {
+            final List<ActionListener<Void>> result;
+            synchronized (mutex) {
+                if (listeners.isEmpty()) {
+                    result = Collections.emptyList();
+                } else {
+                    result = listeners;
+                    listeners = new ArrayList<>();
                 }
-            });
+            }
+            return result;
         }
 
-        private void collectRemoteNodes(Iterator<Supplier<DiscoveryNode>> seedNodes, final TransportService transportService,
-                                        final ConnectionManager manager, ActionListener<Void> listener) {
+        private void collectRemoteNodes(Iterator<Supplier<DiscoveryNode>> seedNodes, ActionListener<Void> listener) {
             if (Thread.currentThread().isInterrupted()) {
                 listener.onFailure(new InterruptedException("remote connect thread got interrupted"));
             }
-            try {
-                if (seedNodes.hasNext()) {
-                    cancellableThreads.executeIO(() -> {
-                        final DiscoveryNode seedNode = maybeAddProxyAddress(proxyAddress, seedNodes.next().get());
-                        logger.debug("[{}] opening connection to seed node: [{}] proxy address: [{}]", clusterAlias, seedNode,
-                            proxyAddress);
-                        final TransportService.HandshakeResponse handshakeResponse;
-                        final ConnectionProfile profile = ConnectionProfile.buildSingleChannelProfile(TransportRequestOptions.Type.REG);
-                        final Transport.Connection connection = PlainActionFuture.get(
-                            fut -> manager.openConnection(seedNode, profile, fut));
-                        boolean success = false;
-                        try {
-                            try {
-                                ConnectionProfile connectionProfile = connectionManager.getConnectionProfile();
-                                handshakeResponse = PlainActionFuture.get(fut ->
-                                    transportService.handshake(connection, connectionProfile.getHandshakeTimeout().millis(),
-                                        getRemoteClusterNamePredicate(), fut));
-                            } catch (IllegalStateException ex) {
-                                logger.warn(new ParameterizedMessage("failed to connect to seed node [{}]", connection.getNode()), ex);
-                                throw ex;
-                            }
-
-                            final DiscoveryNode handshakeNode = maybeAddProxyAddress(proxyAddress, handshakeResponse.getDiscoveryNode());
-                            if (nodePredicate.test(handshakeNode) && manager.size() < maxNumRemoteConnections) {
-                                PlainActionFuture.get(fut -> manager.connectToNode(handshakeNode, null,
-                                    transportService.connectionValidator(handshakeNode), ActionListener.map(fut, x -> null)));
-                                if (remoteClusterName.get() == null) {
-                                    assert handshakeResponse.getClusterName().value() != null;
-                                    remoteClusterName.set(handshakeResponse.getClusterName());
-                                }
-                            }
-                            ClusterStateRequest request = new ClusterStateRequest();
-                            request.clear();
-                            request.nodes(true);
-                            // here we pass on the connection since we can only close it once the sendRequest returns otherwise
-                            // due to the async nature (it will return before it's actually sent) this can cause the request to fail
-                            // due to an already closed connection.
-                            ThreadPool threadPool = transportService.getThreadPool();
-                            ThreadContext threadContext = threadPool.getThreadContext();
-                            TransportService.ContextRestoreResponseHandler<ClusterStateResponse> responseHandler = new TransportService
-                                .ContextRestoreResponseHandler<>(threadContext.newRestorableContext(false),
-                                new SniffClusterStateResponseHandler(connection, listener, seedNodes,
-                                    cancellableThreads));
-                            try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
-                                // we stash any context here since this is an internal execution and should not leak any
-                                // existing context information.
-                                threadContext.markAsSystemContext();
-                                transportService.sendRequest(connection, ClusterStateAction.NAME, request, TransportRequestOptions.EMPTY,
-                                    responseHandler);
-                            }
-                            success = true;
-                        } finally {
-                            if (success == false) {
-                                connection.close();
-                            }
+
+            if (seedNodes.hasNext()) {
+                final DiscoveryNode seedNode = maybeAddProxyAddress(proxyAddress, seedNodes.next().get());
+                logger.debug("[{}] opening connection to seed node: [{}] proxy address: [{}]", clusterAlias, seedNode,
+                    proxyAddress);
+                final ConnectionProfile profile = ConnectionProfile.buildSingleChannelProfile(TransportRequestOptions.Type.REG);
+
+                final StepListener<Transport.Connection> openConnectionStep = new StepListener<>();
+                connectionManager.openConnection(seedNode, profile, openConnectionStep);
+
+                final Consumer<Exception> onFailure = e -> {
+                    if (e instanceof ConnectTransportException ||
+                        e instanceof IOException ||
+                        e instanceof IllegalStateException) {
+                        // ISE if we fail the handshake with an version incompatible node
+                        if (seedNodes.hasNext()) {
+                            logger.debug(() -> new ParameterizedMessage(
+                                "fetching nodes from external cluster [{}] failed moving to next node", clusterAlias), e);
+                            collectRemoteNodes(seedNodes, listener);
+                            return;
                         }
-                    });
-                } else {
-                    listener.onFailure(new IllegalStateException("no seed node left"));
-                }
-            } catch (CancellableThreads.ExecutionCancelledException ex) {
-                logger.warn(() -> new ParameterizedMessage("fetching nodes from external cluster [{}] failed", clusterAlias), ex);
-                listener.onFailure(ex); // we got canceled - fail the listener and step out
-            } catch (ConnectTransportException | IOException | IllegalStateException ex) {
-                // ISE if we fail the handshake with an version incompatible node
-                if (seedNodes.hasNext()) {
-                    logger.debug(() -> new ParameterizedMessage("fetching nodes from external cluster [{}] failed moving to next node",
-                        clusterAlias), ex);
-                    collectRemoteNodes(seedNodes, transportService, manager, listener);
-                } else {
-                    logger.warn(() -> new ParameterizedMessage("fetching nodes from external cluster [{}] failed", clusterAlias), ex);
-                    listener.onFailure(ex);
-                }
+                    }
+                    logger.warn(() -> new ParameterizedMessage("fetching nodes from external cluster [{}] failed", clusterAlias), e);
+                    listener.onFailure(e);
+                };
+
+                final StepListener<TransportService.HandshakeResponse> handShakeStep = new StepListener<>();
+                openConnectionStep.whenComplete(connection -> {
+                    ConnectionProfile connectionProfile = connectionManager.getConnectionProfile();
+                    transportService.handshake(connection, connectionProfile.getHandshakeTimeout().millis(),
+                        getRemoteClusterNamePredicate(), handShakeStep);
+                }, onFailure);
+
+                final StepListener<Void> fullConnectionStep = new StepListener<>();
+                handShakeStep.whenComplete(handshakeResponse -> {
+                    final DiscoveryNode handshakeNode = maybeAddProxyAddress(proxyAddress, handshakeResponse.getDiscoveryNode());
+
+                    if (nodePredicate.test(handshakeNode) && connectionManager.size() < maxNumRemoteConnections) {
+                        connectionManager.connectToNode(handshakeNode, null,
+                            transportService.connectionValidator(handshakeNode), fullConnectionStep);
+                    } else {
+                        fullConnectionStep.onResponse(null);
+                    }
+                }, e -> {
+                    final Transport.Connection connection = openConnectionStep.result();
+                    logger.warn(new ParameterizedMessage("failed to connect to seed node [{}]", connection.getNode()), e);
+                    IOUtils.closeWhileHandlingException(connection);
+                    onFailure.accept(e);
+                });
+
+                fullConnectionStep.whenComplete(aVoid -> {
+                    if (remoteClusterName.get() == null) {
+                        TransportService.HandshakeResponse handshakeResponse = handShakeStep.result();
+                        assert handshakeResponse.getClusterName().value() != null;
+                        remoteClusterName.set(handshakeResponse.getClusterName());
+                    }
+                    final Transport.Connection connection = openConnectionStep.result();
+
+                    ClusterStateRequest request = new ClusterStateRequest();
+                    request.clear();
+                    request.nodes(true);
+                    // here we pass on the connection since we can only close it once the sendRequest returns otherwise
+                    // due to the async nature (it will return before it's actually sent) this can cause the request to fail
+                    // due to an already closed connection.
+                    ThreadPool threadPool = transportService.getThreadPool();
+                    ThreadContext threadContext = threadPool.getThreadContext();
+                    TransportService.ContextRestoreResponseHandler<ClusterStateResponse> responseHandler = new TransportService
+                        .ContextRestoreResponseHandler<>(threadContext.newRestorableContext(false),
+                        new SniffClusterStateResponseHandler(connection, listener, seedNodes));
+                    try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
+                        // we stash any context here since this is an internal execution and should not leak any
+                        // existing context information.
+                        threadContext.markAsSystemContext();
+                        transportService.sendRequest(connection, ClusterStateAction.NAME, request, TransportRequestOptions.EMPTY,
+                            responseHandler);
+                    }
+                }, e -> {
+                    IOUtils.closeWhileHandlingException(openConnectionStep.result());
+                    onFailure.accept(e);
+                });
+            } else {
+                listener.onFailure(new IllegalStateException("no seed node left"));
             }
         }
 
         @Override
         public void close() throws IOException {
-            try {
+            final List<ActionListener<Void>> toNotify;
+            synchronized (mutex) {
                 if (closed.compareAndSet(false, true)) {
-                    cancellableThreads.cancel("connect handler is closed");
-                    running.acquire(); // acquire the semaphore to ensure all connections are closed and all thread joined
-                    running.release();
-                    maybeConnect(); // now go and notify pending listeners
+                    toNotify = listeners;
+                    listeners = Collections.emptyList();
+                } else {
+                    toNotify = Collections.emptyList();
                 }
-            } catch (InterruptedException e) {
-                Thread.currentThread().interrupt();
             }
+            ActionListener.onFailure(toNotify, new AlreadyClosedException("connect handler is already closed"));
         }
 
         final boolean isClosed() {
@@ -564,15 +539,12 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
             private final Transport.Connection connection;
             private final ActionListener<Void> listener;
             private final Iterator<Supplier<DiscoveryNode>> seedNodes;
-            private final CancellableThreads cancellableThreads;
 
             SniffClusterStateResponseHandler(Transport.Connection connection, ActionListener<Void> listener,
-                                             Iterator<Supplier<DiscoveryNode>> seedNodes,
-                                             CancellableThreads cancellableThreads) {
+                                             Iterator<Supplier<DiscoveryNode>> seedNodes) {
                 this.connection = connection;
                 this.listener = listener;
                 this.seedNodes = seedNodes;
-                this.cancellableThreads = cancellableThreads;
             }
 
             @Override
@@ -582,43 +554,44 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
 
             @Override
             public void handleResponse(ClusterStateResponse response) {
-                try {
-                    if (remoteClusterName.get() == null) {
-                        assert response.getClusterName().value() != null;
-                        remoteClusterName.set(response.getClusterName());
-                    }
-                    try (Closeable theConnection = connection) { // the connection is unused - see comment in #collectRemoteNodes
-                        // we have to close this connection before we notify listeners - this is mainly needed for test correctness
-                        // since if we do it afterwards we might fail assertions that check if all high level connections are closed.
-                        // from a code correctness perspective we could also close it afterwards. This try/with block will
-                        // maintain the possibly exceptions thrown from within the try block and suppress the ones that are possible thrown
-                        // by closing the connection
-                        cancellableThreads.executeIO(() -> {
-                            DiscoveryNodes nodes = response.getState().nodes();
-                            Iterable<DiscoveryNode> nodesIter = nodes.getNodes()::valuesIt;
-                            for (DiscoveryNode n : nodesIter) {
-                                DiscoveryNode node = maybeAddProxyAddress(proxyAddress, n);
-                                if (nodePredicate.test(node) && connectionManager.size() < maxNumRemoteConnections) {
-                                    try {
-                                        // noop if node is connected
-                                        PlainActionFuture.get(fut -> connectionManager.connectToNode(node, null,
-                                            transportService.connectionValidator(node), ActionListener.map(fut, x -> null)));
-                                    } catch (ConnectTransportException | IllegalStateException ex) {
+                handleNodes(response.getState().nodes().getNodes().valuesIt());
+            }
+
+            private void handleNodes(Iterator<DiscoveryNode> nodesIter) {
+                while (nodesIter.hasNext()) {
+                    final DiscoveryNode node = maybeAddProxyAddress(proxyAddress, nodesIter.next());
+                    if (nodePredicate.test(node) && connectionManager.size() < maxNumRemoteConnections) {
+                        connectionManager.connectToNode(node, null,
+                            transportService.connectionValidator(node), new ActionListener<>() {
+                                @Override
+                                public void onResponse(Void aVoid) {
+                                    handleNodes(nodesIter);
+                                }
+
+                                @Override
+                                public void onFailure(Exception e) {
+                                    if (e instanceof ConnectTransportException ||
+                                        e instanceof IllegalStateException) {
                                         // ISE if we fail the handshake with an version incompatible node
                                         // fair enough we can't connect just move on
-                                        logger.debug(() -> new ParameterizedMessage("failed to connect to node {}", node), ex);
+                                        logger.debug(() -> new ParameterizedMessage("failed to connect to node {}", node), e);
+                                        handleNodes(nodesIter);
+                                    } else {
+                                        logger.warn(() ->
+                                            new ParameterizedMessage("fetching nodes from external cluster {} failed", clusterAlias), e);
+                                        IOUtils.closeWhileHandlingException(connection);
+                                        collectRemoteNodes(seedNodes, listener);
                                     }
                                 }
-                            }
-                        });
+                            });
+                        return;
                     }
-                    listener.onResponse(null);
-                } catch (CancellableThreads.ExecutionCancelledException ex) {
-                    listener.onFailure(ex); // we got canceled - fail the listener and step out
-                } catch (Exception ex) {
-                    logger.warn(() -> new ParameterizedMessage("fetching nodes from external cluster {} failed", clusterAlias), ex);
-                    collectRemoteNodes(seedNodes, transportService, connectionManager, listener);
                 }
+                // We have to close this connection before we notify listeners - this is mainly needed for test correctness
+                // since if we do it afterwards we might fail assertions that check if all high level connections are closed.
+                // from a code correctness perspective we could also close it afterwards.
+                IOUtils.closeWhileHandlingException(connection);
+                listener.onResponse(null);
             }
 
             @Override
@@ -628,7 +601,7 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
                     IOUtils.closeWhileHandlingException(connection);
                 } finally {
                     // once the connection is closed lets try the next node
-                    collectRemoteNodes(seedNodes, transportService, connectionManager, listener);
+                    collectRemoteNodes(seedNodes, listener);
                 }
             }
 
@@ -640,7 +613,9 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
     }
 
     boolean assertNoRunningConnections() { // for testing only
-        assert connectHandler.running.availablePermits() == 1;
+        synchronized (connectHandler.mutex) {
+            assert connectHandler.listeners.isEmpty();
+        }
         return true;
     }
 

+ 2 - 28
server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java

@@ -47,7 +47,6 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.unit.TimeValue;
-import org.elasticsearch.common.util.CancellableThreads;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.core.internal.io.IOUtils;
@@ -431,20 +430,6 @@ public class RemoteClusterConnectionTests extends ESTestCase {
         ActionListener<Void> listener = ActionListener.wrap(
             x -> latch.countDown(),
             x -> {
-                /*
-                 * This can occur on a thread submitted to the thread pool while we are closing the
-                 * remote cluster connection at the end of the test.
-                 */
-                if (x instanceof CancellableThreads.ExecutionCancelledException) {
-                    try {
-                        // we should already be shutting down
-                        assertEquals(0L, latch.getCount());
-                    } finally {
-                        // ensure we count down the latch on failure as well to not prevent failing tests from ending
-                        latch.countDown();
-                    }
-                    return;
-                }
                 exceptionAtomicReference.set(x);
                 latch.countDown();
             }
@@ -579,7 +564,7 @@ public class RemoteClusterConnectionTests extends ESTestCase {
                 closeRemote.countDown();
                 listenerCalled.await();
                 assertNotNull(exceptionReference.get());
-                expectThrows(CancellableThreads.ExecutionCancelledException.class, () -> {
+                expectThrows(AlreadyClosedException.class, () -> {
                     throw exceptionReference.get();
                 });
 
@@ -639,16 +624,6 @@ public class RemoteClusterConnectionTests extends ESTestCase {
                                                 latch.countDown();
                                             },
                                             x -> {
-                                                /*
-                                                 * This can occur on a thread submitted to the thread pool while we are closing the
-                                                 * remote cluster connection at the end of the test.
-                                                 */
-                                                if (x instanceof CancellableThreads.ExecutionCancelledException) {
-                                                    // we should already be shutting down
-                                                    assertTrue(executed.get());
-                                                    return;
-                                                }
-
                                                 assertTrue(executed.compareAndSet(false, true));
                                                 latch.countDown();
 
@@ -736,8 +711,7 @@ public class RemoteClusterConnectionTests extends ESTestCase {
                                                         throw assertionError;
                                                     }
                                                 }
-                                                if (x instanceof RejectedExecutionException || x instanceof AlreadyClosedException
-                                                    || x instanceof CancellableThreads.ExecutionCancelledException) {
+                                                if (x instanceof RejectedExecutionException || x instanceof AlreadyClosedException) {
                                                     // that's fine
                                                 } else {
                                                     throw new AssertionError(x);