浏览代码

Move ConnectionManager to async APIs (#42636)

This commit converts the ConnectionManager's openConnection and connectToNode methods to
async-style. This will allow us to not block threads anymore when opening connections. This PR also
adapts the cluster coordination subsystem to make use of the new async APIs, allowing to remove
some hacks in the test infrastructure that had to account for the previous synchronous nature of the
connection APIs.
Yannick Welsch 6 年之前
父节点
当前提交
bca865dd42
共有 19 个文件被更改,包括 549 次插入322 次删除
  1. 8 0
      server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java
  2. 15 16
      server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java
  3. 63 32
      server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java
  4. 141 95
      server/src/main/java/org/elasticsearch/transport/ConnectionManager.java
  5. 12 7
      server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java
  6. 6 5
      server/src/main/java/org/elasticsearch/transport/TcpTransport.java
  7. 3 6
      server/src/main/java/org/elasticsearch/transport/Transport.java
  8. 89 43
      server/src/main/java/org/elasticsearch/transport/TransportService.java
  9. 4 6
      server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java
  10. 1 1
      server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java
  11. 126 9
      server/src/test/java/org/elasticsearch/transport/ConnectionManagerTests.java
  12. 25 25
      server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java
  13. 5 3
      server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java
  14. 23 40
      test/framework/src/main/java/org/elasticsearch/test/disruption/DisruptableMockTransport.java
  15. 1 3
      test/framework/src/main/java/org/elasticsearch/test/transport/MockTransport.java
  16. 9 17
      test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java
  17. 5 6
      test/framework/src/main/java/org/elasticsearch/test/transport/StubbableConnectionManager.java
  18. 5 6
      test/framework/src/main/java/org/elasticsearch/test/transport/StubbableTransport.java
  19. 8 2
      test/framework/src/test/java/org/elasticsearch/test/disruption/DisruptableMockTransportTests.java

+ 8 - 0
server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java

@@ -19,12 +19,20 @@
 
 package org.elasticsearch.action.support;
 
+import org.elasticsearch.common.CheckedConsumer;
+
 public class PlainActionFuture<T> extends AdapterActionFuture<T, T> {
 
     public static <T> PlainActionFuture<T> newFuture() {
         return new PlainActionFuture<>();
     }
 
+    public static <T, E extends Exception> T get(CheckedConsumer<PlainActionFuture<T>, E> e) throws E {
+        PlainActionFuture<T> fut = newFuture();
+        e.accept(fut);
+        return fut.actionGet();
+    }
+
     @Override
     protected T convert(T listenerResponse) {
         return listenerResponse;

+ 15 - 16
server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java

@@ -442,23 +442,22 @@ public class Coordinator extends AbstractLifecycleComponent implements Discovery
             return;
         }
 
-        transportService.connectToNode(joinRequest.getSourceNode());
-
-        final ClusterState stateForJoinValidation = getStateForMasterService();
-
-        if (stateForJoinValidation.nodes().isLocalNodeElectedMaster()) {
-            onJoinValidators.forEach(a -> a.accept(joinRequest.getSourceNode(), stateForJoinValidation));
-            if (stateForJoinValidation.getBlocks().hasGlobalBlock(STATE_NOT_RECOVERED_BLOCK) == false) {
-                // we do this in a couple of places including the cluster update thread. This one here is really just best effort
-                // to ensure we fail as fast as possible.
-                JoinTaskExecutor.ensureMajorVersionBarrier(joinRequest.getSourceNode().getVersion(),
-                    stateForJoinValidation.getNodes().getMinNodeVersion());
+        transportService.connectToNode(joinRequest.getSourceNode(), ActionListener.wrap(ignore -> {
+            final ClusterState stateForJoinValidation = getStateForMasterService();
+
+            if (stateForJoinValidation.nodes().isLocalNodeElectedMaster()) {
+                onJoinValidators.forEach(a -> a.accept(joinRequest.getSourceNode(), stateForJoinValidation));
+                if (stateForJoinValidation.getBlocks().hasGlobalBlock(STATE_NOT_RECOVERED_BLOCK) == false) {
+                    // we do this in a couple of places including the cluster update thread. This one here is really just best effort
+                    // to ensure we fail as fast as possible.
+                    JoinTaskExecutor.ensureMajorVersionBarrier(joinRequest.getSourceNode().getVersion(),
+                        stateForJoinValidation.getNodes().getMinNodeVersion());
+                }
+                sendValidateJoinRequest(stateForJoinValidation, joinRequest, joinCallback);
+            } else {
+                processJoinRequest(joinRequest, joinCallback);
             }
-            sendValidateJoinRequest(stateForJoinValidation, joinRequest, joinCallback);
-
-        } else {
-            processJoinRequest(joinRequest, joinCallback);
-        }
+        }, joinCallback::onFailure));
     }
 
     // package private for tests

+ 63 - 32
server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java

@@ -24,6 +24,7 @@ import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.NotifyOnceListener;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.UUIDs;
@@ -70,7 +71,7 @@ public class HandshakingTransportAddressConnector implements TransportAddressCon
     public void connectToRemoteMasterNode(TransportAddress transportAddress, ActionListener<DiscoveryNode> listener) {
         transportService.getThreadPool().generic().execute(new AbstractRunnable() {
             @Override
-            protected void doRun() throws Exception {
+            protected void doRun() {
 
                 // TODO if transportService is already connected to this address then skip the handshaking
 
@@ -80,38 +81,68 @@ public class HandshakingTransportAddressConnector implements TransportAddressCon
                     emptySet(), Version.CURRENT.minimumCompatibilityVersion());
 
                 logger.trace("[{}] opening probe connection", this);
-                final Connection connection = transportService.openConnection(targetNode,
+                transportService.openConnection(targetNode,
                     ConnectionProfile.buildSingleChannelProfile(Type.REG, probeConnectTimeout, probeHandshakeTimeout,
-                        TimeValue.MINUS_ONE, null));
-                logger.trace("[{}] opened probe connection", this);
-
-                final DiscoveryNode remoteNode;
-                try {
-                    remoteNode = transportService.handshake(connection, probeHandshakeTimeout.millis());
-                    // success means (amongst other things) that the cluster names match
-                    logger.trace("[{}] handshake successful: {}", this, remoteNode);
-                } catch (Exception e) {
-                    // we opened a connection and successfully performed a low-level handshake, so we were definitely talking to an
-                    // Elasticsearch node, but the high-level handshake failed indicating some kind of mismatched configurations
-                    // (e.g. cluster name) that the user should address
-                    logger.warn(new ParameterizedMessage("handshake failed for [{}]", this), e);
-                    listener.onFailure(e);
-                    return;
-                } finally {
-                    IOUtils.closeWhileHandlingException(connection);
-                }
-
-                if (remoteNode.equals(transportService.getLocalNode())) {
-                    // TODO cache this result for some time? forever?
-                    listener.onFailure(new ConnectTransportException(remoteNode, "local node found"));
-                } else if (remoteNode.isMasterNode() == false) {
-                    // TODO cache this result for some time?
-                    listener.onFailure(new ConnectTransportException(remoteNode, "non-master-eligible node found"));
-                } else {
-                    transportService.connectToNode(remoteNode);
-                    logger.trace("[{}] full connection successful: {}", this, remoteNode);
-                    listener.onResponse(remoteNode);
-                }
+                        TimeValue.MINUS_ONE, null), new ActionListener<>() {
+                        @Override
+                        public void onResponse(Connection connection) {
+                            logger.trace("[{}] opened probe connection", this);
+
+                            // use NotifyOnceListener to make sure the following line does not result in onFailure being called when
+                            // the connection is closed in the onResponse handler
+                            transportService.handshake(connection, probeHandshakeTimeout.millis(), new NotifyOnceListener<DiscoveryNode>() {
+
+                                @Override
+                                protected void innerOnResponse(DiscoveryNode remoteNode) {
+                                    try {
+                                        // success means (amongst other things) that the cluster names match
+                                        logger.trace("[{}] handshake successful: {}", this, remoteNode);
+                                        IOUtils.closeWhileHandlingException(connection);
+
+                                        if (remoteNode.equals(transportService.getLocalNode())) {
+                                            // TODO cache this result for some time? forever?
+                                            listener.onFailure(new ConnectTransportException(remoteNode, "local node found"));
+                                        } else if (remoteNode.isMasterNode() == false) {
+                                            // TODO cache this result for some time?
+                                            listener.onFailure(new ConnectTransportException(remoteNode, "non-master-eligible node found"));
+                                        } else {
+                                            transportService.connectToNode(remoteNode, new ActionListener<Void>() {
+                                                @Override
+                                                public void onResponse(Void ignored) {
+                                                    logger.trace("[{}] full connection successful: {}", this, remoteNode);
+                                                    listener.onResponse(remoteNode);
+                                                }
+
+                                                @Override
+                                                public void onFailure(Exception e) {
+                                                    listener.onFailure(e);
+                                                }
+                                            });
+                                        }
+                                    } catch (Exception e) {
+                                        listener.onFailure(e);
+                                    }
+                                }
+
+                                @Override
+                                protected void innerOnFailure(Exception e) {
+                                    // we opened a connection and successfully performed a low-level handshake, so we were definitely
+                                    // talking to an Elasticsearch node, but the high-level handshake failed indicating some kind of
+                                    // mismatched configurations (e.g. cluster name) that the user should address
+                                    logger.warn(new ParameterizedMessage("handshake failed for [{}]", this), e);
+                                    IOUtils.closeWhileHandlingException(connection);
+                                    listener.onFailure(e);
+                                }
+
+                            });
+
+                        }
+
+                        @Override
+                        public void onFailure(Exception e) {
+                            listener.onFailure(e);
+                        }
+                    });
             }
 
             @Override

+ 141 - 95
server/src/main/java/org/elasticsearch/transport/ConnectionManager.java

@@ -22,24 +22,24 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.node.DiscoveryNode;
-import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.lease.Releasable;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.AbstractRefCounted;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.common.util.concurrent.KeyedLock;
+import org.elasticsearch.common.util.concurrent.RunOnce;
 import org.elasticsearch.core.internal.io.IOUtils;
 
 import java.io.Closeable;
-import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Iterator;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.locks.ReadWriteLock;
-import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 /**
  * This class manages node connections. The connection is opened by the underlying transport. Once the
@@ -51,11 +51,18 @@ public class ConnectionManager implements Closeable {
     private static final Logger logger = LogManager.getLogger(ConnectionManager.class);
 
     private final ConcurrentMap<DiscoveryNode, Transport.Connection> connectedNodes = ConcurrentCollections.newConcurrentMap();
-    private final KeyedLock<String> connectionLock = new KeyedLock<>();
+    private final KeyedLock<String> connectionLock = new KeyedLock<>(); // protects concurrent access to connectingNodes
+    private final Map<DiscoveryNode, List<ActionListener<Void>>> connectingNodes = ConcurrentCollections.newConcurrentMap();
+    private final AbstractRefCounted connectingRefCounter = new AbstractRefCounted("connection manager") {
+        @Override
+        protected void closeInternal() {
+            closeLatch.countDown();
+        }
+    };
     private final Transport transport;
     private final ConnectionProfile defaultProfile;
-    private final AtomicBoolean isClosed = new AtomicBoolean(false);
-    private final ReadWriteLock closeLock = new ReentrantReadWriteLock();
+    private final AtomicBoolean closing = new AtomicBoolean(false);
+    private final CountDownLatch closeLatch = new CountDownLatch(1);
     private final DelegatingNodeConnectionListener connectionListener = new DelegatingNodeConnectionListener();
 
     public ConnectionManager(Settings settings, Transport transport) {
@@ -75,66 +82,119 @@ public class ConnectionManager implements Closeable {
         this.connectionListener.listeners.remove(listener);
     }
 
-    public Transport.Connection openConnection(DiscoveryNode node, ConnectionProfile connectionProfile) {
+    public void openConnection(DiscoveryNode node, ConnectionProfile connectionProfile, ActionListener<Transport.Connection> listener) {
         ConnectionProfile resolvedProfile = ConnectionProfile.resolveConnectionProfile(connectionProfile, defaultProfile);
-        return internalOpenConnection(node, resolvedProfile);
+        internalOpenConnection(node, resolvedProfile, listener);
+    }
+
+    @FunctionalInterface
+    public interface ConnectionValidator {
+        void validate(Transport.Connection connection, ConnectionProfile profile, ActionListener<Void> listener);
     }
 
     /**
      * Connects to a node with the given connection profile. If the node is already connected this method has no effect.
      * Once a successful is established, it can be validated before being exposed.
+     * The ActionListener will be called on the calling thread or the generic thread pool.
      */
     public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile,
-                              CheckedBiConsumer<Transport.Connection, ConnectionProfile, IOException> connectionValidator)
-        throws ConnectTransportException {
+                              ConnectionValidator connectionValidator,
+                              ActionListener<Void> listener) throws ConnectTransportException {
         ConnectionProfile resolvedProfile = ConnectionProfile.resolveConnectionProfile(connectionProfile, defaultProfile);
         if (node == null) {
-            throw new ConnectTransportException(null, "can't connect to a null node");
+            listener.onFailure(new ConnectTransportException(null, "can't connect to a null node"));
+            return;
         }
-        closeLock.readLock().lock(); // ensure we don't open connections while we are closing
-        try {
-            ensureOpen();
-            try (Releasable ignored = connectionLock.acquire(node.getId())) {
-                Transport.Connection connection = connectedNodes.get(node);
-                if (connection != null) {
-                    return;
-                }
-                boolean success = false;
-                try {
-                    connection = internalOpenConnection(node, resolvedProfile);
-                    connectionValidator.accept(connection, resolvedProfile);
-                    // we acquire a connection lock, so no way there is an existing connection
-                    connectedNodes.put(node, connection);
-                    if (logger.isDebugEnabled()) {
-                        logger.debug("connected to node [{}]", node);
-                    }
+
+        if (connectingRefCounter.tryIncRef() == false) {
+            listener.onFailure(new IllegalStateException("connection manager is closed"));
+            return;
+        }
+
+        try (Releasable lock = connectionLock.acquire(node.getId())) {
+            Transport.Connection connection = connectedNodes.get(node);
+            if (connection != null) {
+                assert connectingNodes.containsKey(node) == false;
+                lock.close();
+                connectingRefCounter.decRef();
+                listener.onResponse(null);
+                return;
+            }
+
+            final List<ActionListener<Void>> connectionListeners = connectingNodes.computeIfAbsent(node, n -> new ArrayList<>());
+            connectionListeners.add(listener);
+            if (connectionListeners.size() > 1) {
+                // wait on previous entry to complete connection attempt
+                connectingRefCounter.decRef();
+                return;
+            }
+        }
+
+        final RunOnce releaseOnce = new RunOnce(connectingRefCounter::decRef);
+
+        internalOpenConnection(node, resolvedProfile, ActionListener.wrap(conn -> {
+            connectionValidator.validate(conn, resolvedProfile, ActionListener.wrap(
+                ignored -> {
+                    assert Transports.assertNotTransportThread("connection validator success");
+                    boolean success = false;
+                    List<ActionListener<Void>> listeners = null;
                     try {
-                        connectionListener.onNodeConnected(node);
+                        // we acquire a connection lock, so no way there is an existing connection
+                        try (Releasable ignored2 = connectionLock.acquire(node.getId())) {
+                            connectedNodes.put(node, conn);
+                            if (logger.isDebugEnabled()) {
+                                logger.debug("connected to node [{}]", node);
+                            }
+                            try {
+                                connectionListener.onNodeConnected(node);
+                            } finally {
+                                final Transport.Connection finalConnection = conn;
+                                conn.addCloseListener(ActionListener.wrap(() -> {
+                                    logger.trace("unregistering {} after connection close and marking as disconnected", node);
+                                    connectedNodes.remove(node, finalConnection);
+                                    connectionListener.onNodeDisconnected(node);
+                                }));
+                            }
+                            if (conn.isClosed()) {
+                                throw new NodeNotConnectedException(node, "connection concurrently closed");
+                            }
+                            success = true;
+                            listeners = connectingNodes.remove(node);
+                        }
+                    } catch (ConnectTransportException e) {
+                        throw e;
+                    } catch (Exception e) {
+                        throw new ConnectTransportException(node, "general node connection failure", e);
                     } finally {
-                        final Transport.Connection finalConnection = connection;
-                        connection.addCloseListener(ActionListener.wrap(() -> {
-                            connectedNodes.remove(node, finalConnection);
-                            connectionListener.onNodeDisconnected(node);
-                        }));
+                        if (success == false) { // close the connection if there is a failure
+                            logger.trace(() -> new ParameterizedMessage("failed to connect to [{}], cleaning dangling connections", node));
+                            IOUtils.closeWhileHandlingException(conn);
+                        } else {
+                            releaseOnce.run();
+                            ActionListener.onResponse(listeners, null);
+                        }
                     }
-                    if (connection.isClosed()) {
-                        throw new NodeNotConnectedException(node, "connection concurrently closed");
+                }, e -> {
+                    assert Transports.assertNotTransportThread("connection validator failure");
+                    IOUtils.closeWhileHandlingException(conn);
+                    final List<ActionListener<Void>> listeners;
+                    try (Releasable ignored = connectionLock.acquire(node.getId())) {
+                        listeners = connectingNodes.remove(node);
                     }
-                    success = true;
-                } catch (ConnectTransportException e) {
-                    throw e;
-                } catch (Exception e) {
-                    throw new ConnectTransportException(node, "general node connection failure", e);
-                } finally {
-                    if (success == false) { // close the connection if there is a failure
-                        logger.trace(() -> new ParameterizedMessage("failed to connect to [{}], cleaning dangling connections", node));
-                        IOUtils.closeWhileHandlingException(connection);
-                    }
-                }
+                    releaseOnce.run();
+                    ActionListener.onFailure(listeners, e);
+                }));
+        }, e -> {
+            assert Transports.assertNotTransportThread("internalOpenConnection failure");
+            final List<ActionListener<Void>> listeners;
+            try (Releasable ignored = connectionLock.acquire(node.getId())) {
+                listeners = connectingNodes.remove(node);
             }
-        } finally {
-            closeLock.readLock().unlock();
-        }
+            releaseOnce.run();
+            if (listeners != null) {
+                ActionListener.onFailure(listeners, e);
+            }
+        }));
     }
 
     /**
@@ -143,7 +203,7 @@ public class ConnectionManager implements Closeable {
      * maintained by this connection manager
      *
      * @throws NodeNotConnectedException if the node is not connected
-     * @see #connectToNode(DiscoveryNode, ConnectionProfile, CheckedBiConsumer)
+     * @see #connectToNode(DiscoveryNode, ConnectionProfile, ConnectionValidator, ActionListener)
      */
     public Transport.Connection getConnection(DiscoveryNode node) {
         Transport.Connection connection = connectedNodes.get(node);
@@ -180,55 +240,41 @@ public class ConnectionManager implements Closeable {
 
     @Override
     public void close() {
-        Transports.assertNotTransportThread("Closing ConnectionManager");
-        if (isClosed.compareAndSet(false, true)) {
-            closeLock.writeLock().lock();
+        assert Transports.assertNotTransportThread("Closing ConnectionManager");
+        if (closing.compareAndSet(false, true)) {
+            connectingRefCounter.decRef();
             try {
-                // we are holding a write lock so nobody adds to the connectedNodes / openConnections map - it's safe to first close
-                // all instances and then clear them maps
-                Iterator<Map.Entry<DiscoveryNode, Transport.Connection>> iterator = connectedNodes.entrySet().iterator();
-                while (iterator.hasNext()) {
-                    Map.Entry<DiscoveryNode, Transport.Connection> next = iterator.next();
-                    try {
-                        IOUtils.closeWhileHandlingException(next.getValue());
-                    } finally {
-                        iterator.remove();
-                    }
+                closeLatch.await();
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                throw new IllegalStateException(e);
+            }
+            Iterator<Map.Entry<DiscoveryNode, Transport.Connection>> iterator = connectedNodes.entrySet().iterator();
+            while (iterator.hasNext()) {
+                Map.Entry<DiscoveryNode, Transport.Connection> next = iterator.next();
+                try {
+                    IOUtils.closeWhileHandlingException(next.getValue());
+                } finally {
+                    iterator.remove();
                 }
-            } finally {
-                closeLock.writeLock().unlock();
             }
         }
     }
 
-    private Transport.Connection internalOpenConnection(DiscoveryNode node, ConnectionProfile connectionProfile) {
-        PlainActionFuture<Transport.Connection> future = PlainActionFuture.newFuture();
-        Releasable pendingConnection = transport.openConnection(node, connectionProfile, future);
-        Transport.Connection connection;
-        try {
-            connection = future.actionGet();
-        } catch (IllegalStateException e) {
-            // If the future was interrupted we must cancel the pending connection to avoid channels leaking
-            if (e.getCause() instanceof InterruptedException) {
-                pendingConnection.close();
+    private void internalOpenConnection(DiscoveryNode node, ConnectionProfile connectionProfile,
+                                        ActionListener<Transport.Connection> listener) {
+        transport.openConnection(node, connectionProfile, ActionListener.map(listener, connection -> {
+            assert Transports.assertNotTransportThread("internalOpenConnection success");
+            try {
+                connectionListener.onConnectionOpened(connection);
+            } finally {
+                connection.addCloseListener(ActionListener.wrap(() -> connectionListener.onConnectionClosed(connection)));
             }
-            throw e;
-        }
-        try {
-            connectionListener.onConnectionOpened(connection);
-        } finally {
-            connection.addCloseListener(ActionListener.wrap(() -> connectionListener.onConnectionClosed(connection)));
-        }
-        if (connection.isClosed()) {
-            throw new ConnectTransportException(node, "a channel closed while connecting");
-        }
-        return connection;
-    }
-
-    private void ensureOpen() {
-        if (isClosed.get()) {
-            throw new IllegalStateException("connection manager is closed");
-        }
+            if (connection.isClosed()) {
+                throw new ConnectTransportException(node, "a channel closed while connecting");
+            }
+            return connection;
+        }));
     }
 
     ConnectionProfile getConnectionProfile() {

+ 12 - 7
server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java

@@ -29,6 +29,7 @@ import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
 import org.elasticsearch.action.support.ContextPreservingActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
@@ -448,14 +449,16 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
                         logger.debug("[{}] opening connection to seed node: [{}] proxy address: [{}]", clusterAlias, seedNode,
                             proxyAddress);
                         final TransportService.HandshakeResponse handshakeResponse;
-                        ConnectionProfile profile = ConnectionProfile.buildSingleChannelProfile(TransportRequestOptions.Type.REG);
-                        Transport.Connection connection = manager.openConnection(seedNode, profile);
+                        final ConnectionProfile profile = ConnectionProfile.buildSingleChannelProfile(TransportRequestOptions.Type.REG);
+                        final Transport.Connection connection = PlainActionFuture.get(
+                            fut -> manager.openConnection(seedNode, profile, fut));
                         boolean success = false;
                         try {
                             try {
                                 ConnectionProfile connectionProfile = connectionManager.getConnectionProfile();
-                                handshakeResponse = transportService.handshake(connection, connectionProfile.getHandshakeTimeout().millis(),
-                                    (c) -> remoteClusterName.get() == null ? true : c.equals(remoteClusterName.get()));
+                                handshakeResponse = PlainActionFuture.get(fut ->
+                                    transportService.handshake(connection, connectionProfile.getHandshakeTimeout().millis(),
+                                        (c) -> remoteClusterName.get() == null ? true : c.equals(remoteClusterName.get()), fut));
                             } catch (IllegalStateException ex) {
                                 logger.warn(() -> new ParameterizedMessage("seed node {} cluster name mismatch expected " +
                                     "cluster name {}", connection.getNode(), remoteClusterName.get()), ex);
@@ -464,7 +467,8 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
 
                             final DiscoveryNode handshakeNode = maybeAddProxyAddress(proxyAddress, handshakeResponse.getDiscoveryNode());
                             if (nodePredicate.test(handshakeNode) && connectedNodes.size() < maxNumRemoteConnections) {
-                                manager.connectToNode(handshakeNode, null, transportService.connectionValidator(handshakeNode));
+                                PlainActionFuture.get(fut -> manager.connectToNode(handshakeNode, null,
+                                    transportService.connectionValidator(handshakeNode), ActionListener.map(fut, x -> null)));
                                 if (remoteClusterName.get() == null) {
                                     assert handshakeResponse.getClusterName().value() != null;
                                     remoteClusterName.set(handshakeResponse.getClusterName());
@@ -578,8 +582,9 @@ final class RemoteClusterConnection implements TransportConnectionListener, Clos
                                 DiscoveryNode node = maybeAddProxyAddress(proxyAddress, n);
                                 if (nodePredicate.test(node) && connectedNodes.size() < maxNumRemoteConnections) {
                                     try {
-                                        connectionManager.connectToNode(node, null,
-                                            transportService.connectionValidator(node)); // noop if node is connected
+                                        // noop if node is connected
+                                        PlainActionFuture.get(fut -> connectionManager.connectToNode(node, null,
+                                            transportService.connectionValidator(node), ActionListener.map(fut, x -> null)));
                                         connectedNodes.add(node);
                                     } catch (ConnectTransportException | IllegalStateException ex) {
                                         // ISE if we fail the handshake with an version incompatible node

+ 6 - 5
server/src/main/java/org/elasticsearch/transport/TcpTransport.java

@@ -26,6 +26,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.ThreadedActionListener;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.Booleans;
 import org.elasticsearch.common.Strings;
@@ -35,7 +36,6 @@ import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.component.AbstractLifecycleComponent;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.lease.Releasable;
 import org.elasticsearch.common.metrics.MeanMetric;
 import org.elasticsearch.common.network.CloseableChannel;
 import org.elasticsearch.common.network.NetworkAddress;
@@ -254,7 +254,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
     }
 
     @Override
-    public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Transport.Connection> listener) {
+    public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Transport.Connection> listener) {
+
         Objects.requireNonNull(profile, "connection profile cannot be null");
         if (node == null) {
             throw new ConnectTransportException(null, "can't open connection to a null node");
@@ -263,8 +264,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
         closeLock.readLock().lock(); // ensure we don't open connections while we are closing
         try {
             ensureOpen();
-            List<TcpChannel> pendingChannels = initiateConnection(node, finalProfile, listener);
-            return () -> CloseableChannel.closeChannels(pendingChannels, false);
+            initiateConnection(node, finalProfile, listener);
         } finally {
             closeLock.readLock().unlock();
         }
@@ -293,7 +293,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
             }
         }
 
-        ChannelsConnectedListener channelsConnectedListener = new ChannelsConnectedListener(node, connectionProfile, channels, listener);
+        ChannelsConnectedListener channelsConnectedListener = new ChannelsConnectedListener(node, connectionProfile, channels,
+            new ThreadedActionListener<>(logger, threadPool, ThreadPool.Names.GENERIC, listener, false));
 
         for (TcpChannel channel : channels) {
             channel.addConnectListener(channelsConnectedListener);

+ 3 - 6
server/src/main/java/org/elasticsearch/transport/Transport.java

@@ -25,7 +25,6 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.breaker.NoopCircuitBreaker;
 import org.elasticsearch.common.component.LifecycleComponent;
-import org.elasticsearch.common.lease.Releasable;
 import org.elasticsearch.common.transport.BoundTransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
@@ -80,12 +79,10 @@ public interface Transport extends LifecycleComponent {
     }
 
     /**
-     * Opens a new connection to the given node. When the connection is fully connected, the listener is
-     * called. A {@link Releasable} is returned representing the pending connection. If the caller of this
-     * method decides to move on before the listener is called with the completed connection, they should
-     * release the pending connection to prevent hanging connections.
+     * Opens a new connection to the given node. When the connection is fully connected, the listener is called.
+     * The ActionListener will be called on the calling thread or the generic thread pool.
      */
-    Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Transport.Connection> listener);
+    void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Transport.Connection> listener);
 
     TransportStats getStats();
 

+ 89 - 43
server/src/main/java/org/elasticsearch/transport/TransportService.java

@@ -24,9 +24,10 @@ import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionListenerResponseHandler;
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.node.DiscoveryNode;
-import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.component.AbstractLifecycleComponent;
@@ -321,7 +322,7 @@ public class TransportService extends AbstractLifecycleComponent implements Tran
      * @param node the node to connect to
      */
     public void connectToNode(DiscoveryNode node) throws ConnectTransportException {
-        connectToNode(node, null);
+        connectToNode(node, (ConnectionProfile) null);
     }
 
     /**
@@ -331,34 +332,74 @@ public class TransportService extends AbstractLifecycleComponent implements Tran
      * @param connectionProfile the connection profile to use when connecting to this node
      */
     public void connectToNode(final DiscoveryNode node, ConnectionProfile connectionProfile) {
+        PlainActionFuture.get(fut -> connectToNode(node, connectionProfile, ActionListener.map(fut, x -> null)));
+    }
+
+    /**
+     * Connect to the specified node with the given connection profile.
+     * The ActionListener will be called on the calling thread or the generic thread pool.
+     *
+     * @param node the node to connect to
+     * @param listener the action listener to notify
+     */
+    public void connectToNode(DiscoveryNode node, ActionListener<Void> listener) throws ConnectTransportException {
+        connectToNode(node, null, listener);
+    }
+
+    /**
+     * Connect to the specified node with the given connection profile.
+     * The ActionListener will be called on the calling thread or the generic thread pool.
+     *
+     * @param node the node to connect to
+     * @param connectionProfile the connection profile to use when connecting to this node
+     * @param listener the action listener to notify
+     */
+    public void connectToNode(final DiscoveryNode node, ConnectionProfile connectionProfile, ActionListener<Void> listener) {
         if (isLocalNode(node)) {
+            listener.onResponse(null);
             return;
         }
-        connectionManager.connectToNode(node, connectionProfile, connectionValidator(node));
+        connectionManager.connectToNode(node, connectionProfile, connectionValidator(node), listener);
     }
 
-    public CheckedBiConsumer<Transport.Connection, ConnectionProfile, IOException> connectionValidator(DiscoveryNode node) {
-        return (newConnection, actualProfile) -> {
+    public ConnectionManager.ConnectionValidator connectionValidator(DiscoveryNode node) {
+        return (newConnection, actualProfile, listener) -> {
             // We don't validate cluster names to allow for CCS connections.
-            final DiscoveryNode remote = handshake(newConnection, actualProfile.getHandshakeTimeout().millis(), cn -> true).discoveryNode;
-            if (node.equals(remote) == false) {
-                throw new ConnectTransportException(node, "handshake failed. unexpected remote node " + remote);
-            }
+            handshake(newConnection, actualProfile.getHandshakeTimeout().millis(), cn -> true, ActionListener.map(listener, resp -> {
+                final DiscoveryNode remote = resp.discoveryNode;
+                if (node.equals(remote) == false) {
+                    throw new ConnectTransportException(node, "handshake failed. unexpected remote node " + remote);
+                }
+                return null;
+            }));
         };
-
     }
 
     /**
      * Establishes and returns a new connection to the given node. The connection is NOT maintained by this service, it's the callers
      * responsibility to close the connection once it goes out of scope.
+     * The ActionListener will be called on the calling thread or the generic thread pool.
+     * @param node the node to connect to
+     * @param connectionProfile the connection profile to use
+     */
+    public Transport.Connection openConnection(final DiscoveryNode node, ConnectionProfile connectionProfile) {
+        return PlainActionFuture.get(fut -> openConnection(node, connectionProfile, fut));
+    }
+
+    /**
+     * Establishes a new connection to the given node. The connection is NOT maintained by this service, it's the callers
+     * responsibility to close the connection once it goes out of scope.
+     * The ActionListener will be called on the calling thread or the generic thread pool.
      * @param node the node to connect to
      * @param connectionProfile the connection profile to use
+     * @param listener the action listener to notify
      */
-    public Transport.Connection openConnection(final DiscoveryNode node, ConnectionProfile connectionProfile) throws IOException {
+    public void openConnection(final DiscoveryNode node, ConnectionProfile connectionProfile,
+                               ActionListener<Transport.Connection> listener) {
         if (isLocalNode(node)) {
-            return localNodeConnection;
+            listener.onResponse(localNodeConnection);
         } else {
-            return connectionManager.openConnection(node, connectionProfile);
+            connectionManager.openConnection(node, connectionProfile, listener);
         }
     }
 
@@ -367,17 +408,19 @@ public class TransportService extends AbstractLifecycleComponent implements Tran
      * and returns the discovery node of the node the connection
      * was established with. The handshake will fail if the cluster
      * name on the target node mismatches the local cluster name.
+     * The ActionListener will be called on the calling thread or the generic thread pool.
      *
      * @param connection       the connection to a specific node
      * @param handshakeTimeout handshake timeout
-     * @return the connected node
+     * @param listener         action listener to notify
      * @throws ConnectTransportException if the connection failed
      * @throws IllegalStateException if the handshake failed
      */
-    public DiscoveryNode handshake(
-            final Transport.Connection connection,
-            final long handshakeTimeout) throws ConnectTransportException {
-        return handshake(connection, handshakeTimeout, clusterName::equals).discoveryNode;
+    public void handshake(
+        final Transport.Connection connection,
+        final long handshakeTimeout,
+        final ActionListener<DiscoveryNode> listener) {
+        handshake(connection, handshakeTimeout, clusterName::equals, ActionListener.map(listener, HandshakeResponse::getDiscoveryNode));
     }
 
     /**
@@ -385,40 +428,43 @@ public class TransportService extends AbstractLifecycleComponent implements Tran
      * and returns the discovery node of the node the connection
      * was established with. The handshake will fail if the cluster
      * name on the target node doesn't match the local cluster name.
+     * The ActionListener will be called on the calling thread or the generic thread pool.
      *
      * @param connection       the connection to a specific node
      * @param handshakeTimeout handshake timeout
      * @param clusterNamePredicate cluster name validation predicate
-     * @return the handshake response
+     * @param listener         action listener to notify
      * @throws IllegalStateException if the handshake failed
      */
-    public HandshakeResponse handshake(
+    public void handshake(
         final Transport.Connection connection,
-        final long handshakeTimeout, Predicate<ClusterName> clusterNamePredicate) {
-        final HandshakeResponse response;
+        final long handshakeTimeout, Predicate<ClusterName> clusterNamePredicate,
+        final ActionListener<HandshakeResponse> listener) {
         final DiscoveryNode node = connection.getNode();
-        try {
-            PlainTransportFuture<HandshakeResponse> futureHandler = new PlainTransportFuture<>(
-                new FutureTransportResponseHandler<HandshakeResponse>() {
-                @Override
-                public HandshakeResponse read(StreamInput in) throws IOException {
-                    return new HandshakeResponse(in);
-                }
-            });
-            sendRequest(connection, HANDSHAKE_ACTION_NAME, HandshakeRequest.INSTANCE,
-                TransportRequestOptions.builder().withTimeout(handshakeTimeout).build(), futureHandler);
-            response = futureHandler.txGet();
-        } catch (Exception e) {
-            throw new IllegalStateException("handshake failed with " + node, e);
-        }
-
-        if (!clusterNamePredicate.test(response.clusterName)) {
-            throw new IllegalStateException("handshake failed, mismatched cluster name [" + response.clusterName + "] - " + node);
-        } else if (response.version.isCompatible(localNode.getVersion()) == false) {
-            throw new IllegalStateException("handshake failed, incompatible version [" + response.version + "] - " + node);
-        }
+        sendRequest(connection, HANDSHAKE_ACTION_NAME, HandshakeRequest.INSTANCE,
+            TransportRequestOptions.builder().withTimeout(handshakeTimeout).build(),
+            new ActionListenerResponseHandler<>(
+                new ActionListener<>() {
+                    @Override
+                    public void onResponse(HandshakeResponse response) {
+                        if (!clusterNamePredicate.test(response.clusterName)) {
+                            listener.onFailure(new IllegalStateException("handshake failed, mismatched cluster name [" +
+                                response.clusterName + "] - " + node.toString()));
+                        } else if (response.version.isCompatible(localNode.getVersion()) == false) {
+                            listener.onFailure(new IllegalStateException("handshake failed, incompatible version [" +
+                                response.version + "] - " + node));
+                        } else {
+                            listener.onResponse(response);
+                        }
+                    }
 
-        return response;
+                    @Override
+                    public void onFailure(Exception e) {
+                        listener.onFailure(e);
+                    }
+                }
+                , HandshakeResponse::new, ThreadPool.Names.GENERIC
+            ));
     }
 
     public ConnectionManager getConnectionManager() {

+ 4 - 6
server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java

@@ -33,7 +33,6 @@ import org.elasticsearch.common.CheckedRunnable;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.component.Lifecycle;
 import org.elasticsearch.common.component.LifecycleListener;
-import org.elasticsearch.common.lease.Releasable;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.BoundTransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
@@ -388,8 +387,9 @@ public class NodeConnectionsServiceTests extends ESTestCase {
         }
 
         @Override
-        public HandshakeResponse handshake(Transport.Connection connection, long timeout, Predicate<ClusterName> clusterNamePredicate) {
-            return new HandshakeResponse(connection.getNode(), new ClusterName(""), Version.CURRENT);
+        public void handshake(Transport.Connection connection, long timeout, Predicate<ClusterName> clusterNamePredicate,
+                              ActionListener<HandshakeResponse> listener) {
+            listener.onResponse(new HandshakeResponse(connection.getNode(), new ClusterName(""), Version.CURRENT));
         }
 
         @Override
@@ -439,7 +439,7 @@ public class NodeConnectionsServiceTests extends ESTestCase {
         }
 
         @Override
-        public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
+        public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
             if (profile == null && randomConnectionExceptions && randomBoolean()) {
                 threadPool.generic().execute(() -> listener.onFailure(new ConnectTransportException(node, "simulated")));
             } else {
@@ -468,8 +468,6 @@ public class NodeConnectionsServiceTests extends ESTestCase {
                     }
                 }));
             }
-            return () -> {
-            };
         }
 
         @Override

+ 1 - 1
server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java

@@ -753,7 +753,7 @@ public class SnapshotResiliencyTests extends ESTestCase {
             disconnectedNodes.forEach(nodeName -> {
                 if (testClusterNodes.nodes.containsKey(nodeName)) {
                     final DiscoveryNode node = testClusterNodes.nodes.get(nodeName).node;
-                    testClusterNodes.nodes.values().forEach(n -> n.transportService.getConnectionManager().openConnection(node, null));
+                    testClusterNodes.nodes.values().forEach(n -> n.transportService.openConnection(node, null));
                 }
             });
         }

+ 126 - 9
server/src/test/java/org/elasticsearch/transport/ConnectionManagerTests.java

@@ -21,8 +21,8 @@ package org.elasticsearch.transport;
 
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.node.DiscoveryNode;
-import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.unit.TimeValue;
@@ -31,8 +31,12 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.junit.After;
 import org.junit.Before;
 
-import java.io.IOException;
 import java.net.InetAddress;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.BrokenBarrierException;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
@@ -94,8 +98,12 @@ public class ConnectionManagerTests extends ESTestCase {
         assertFalse(connectionManager.nodeConnected(node));
 
         AtomicReference<Transport.Connection> connectionRef = new AtomicReference<>();
-        CheckedBiConsumer<Transport.Connection, ConnectionProfile, IOException> validator = (c, p) -> connectionRef.set(c);
-        connectionManager.connectToNode(node, connectionProfile, validator);
+        ConnectionManager.ConnectionValidator validator = (c, p, l) -> {
+            connectionRef.set(c);
+            l.onResponse(null);
+        };
+        PlainActionFuture.get(
+            fut -> connectionManager.connectToNode(node, connectionProfile, validator, ActionListener.map(fut, x -> null)));
 
         assertFalse(connection.isClosed());
         assertTrue(connectionManager.nodeConnected(node));
@@ -115,7 +123,78 @@ public class ConnectionManagerTests extends ESTestCase {
         assertEquals(1, nodeDisconnectedCount.get());
     }
 
-    public void testConnectFails() {
+    public void testConcurrentConnectsAndDisconnects() throws BrokenBarrierException, InterruptedException {
+        DiscoveryNode node = new DiscoveryNode("", new TransportAddress(InetAddress.getLoopbackAddress(), 0), Version.CURRENT);
+        Transport.Connection connection = new TestConnect(node);
+        doAnswer(invocationOnMock -> {
+            ActionListener<Transport.Connection> listener = (ActionListener<Transport.Connection>) invocationOnMock.getArguments()[2];
+            if (rarely()) {
+                listener.onResponse(connection);
+            } if (frequently()) {
+                threadPool.generic().execute(() -> listener.onResponse(connection));
+            } else {
+                threadPool.generic().execute(() -> listener.onFailure(new IllegalStateException("dummy exception")));
+            }
+            return null;
+        }).when(transport).openConnection(eq(node), eq(connectionProfile), any(ActionListener.class));
+
+        assertFalse(connectionManager.nodeConnected(node));
+
+        ConnectionManager.ConnectionValidator validator = (c, p, l) -> {
+            if (rarely()) {
+                l.onResponse(null);
+            } if (frequently()) {
+                threadPool.generic().execute(() -> l.onResponse(null));
+            } else {
+                threadPool.generic().execute(() -> l.onFailure(new IllegalStateException("dummy exception")));
+            }
+        };
+
+        CyclicBarrier barrier = new CyclicBarrier(11);
+        List<Thread> threads = new ArrayList<>();
+        AtomicInteger nodeConnectedCount = new AtomicInteger();
+        AtomicInteger nodeFailureCount = new AtomicInteger();
+        for (int i = 0; i < 10; i++) {
+            Thread thread = new Thread(() -> {
+                try {
+                    barrier.await();
+                } catch (InterruptedException | BrokenBarrierException e) {
+                    throw new RuntimeException(e);
+                }
+                CountDownLatch latch = new CountDownLatch(1);
+                connectionManager.connectToNode(node, connectionProfile, validator,
+                    ActionListener.wrap(c -> {
+                        nodeConnectedCount.incrementAndGet();
+                        assert latch.getCount() == 1;
+                        latch.countDown();
+                    }, e -> {
+                        nodeFailureCount.incrementAndGet();
+                        assert latch.getCount() == 1;
+                        latch.countDown();
+                    }));
+                try {
+                    latch.await();
+                } catch (InterruptedException e) {
+                    throw new IllegalStateException(e);
+                }
+            });
+            threads.add(thread);
+            thread.start();
+        }
+
+        barrier.await();
+        threads.forEach(t -> {
+            try {
+                t.join();
+            } catch (InterruptedException e) {
+                throw new IllegalStateException(e);
+            }
+        });
+
+        assertEquals(10, nodeConnectedCount.get() + nodeFailureCount.get());
+    }
+
+    public void testConnectFailsDuringValidation() {
         AtomicInteger nodeConnectedCount = new AtomicInteger();
         AtomicInteger nodeDisconnectedCount = new AtomicInteger();
         connectionManager.addListener(new TransportConnectionListener() {
@@ -141,11 +220,11 @@ public class ConnectionManagerTests extends ESTestCase {
 
         assertFalse(connectionManager.nodeConnected(node));
 
-        CheckedBiConsumer<Transport.Connection, ConnectionProfile, IOException> validator = (c, p) -> {
-            throw new ConnectTransportException(node, "");
-        };
+        ConnectionManager.ConnectionValidator validator = (c, p, l) -> l.onFailure(new ConnectTransportException(node, ""));
 
-        expectThrows(ConnectTransportException.class, () -> connectionManager.connectToNode(node, connectionProfile, validator));
+        PlainActionFuture<Void> fut = new PlainActionFuture<>();
+        connectionManager.connectToNode(node, connectionProfile, validator, fut);
+        expectThrows(ConnectTransportException.class, () -> fut.actionGet());
 
         assertTrue(connection.isClosed());
         assertFalse(connectionManager.nodeConnected(node));
@@ -155,6 +234,44 @@ public class ConnectionManagerTests extends ESTestCase {
         assertEquals(0, nodeDisconnectedCount.get());
     }
 
+    public void testConnectFailsDuringConnect() {
+        AtomicInteger nodeConnectedCount = new AtomicInteger();
+        AtomicInteger nodeDisconnectedCount = new AtomicInteger();
+        connectionManager.addListener(new TransportConnectionListener() {
+            @Override
+            public void onNodeConnected(DiscoveryNode node) {
+                nodeConnectedCount.incrementAndGet();
+            }
+
+            @Override
+            public void onNodeDisconnected(DiscoveryNode node) {
+                nodeDisconnectedCount.incrementAndGet();
+            }
+        });
+
+
+        DiscoveryNode node = new DiscoveryNode("", new TransportAddress(InetAddress.getLoopbackAddress(), 0), Version.CURRENT);
+        doAnswer(invocationOnMock -> {
+            ActionListener<Transport.Connection> listener = (ActionListener<Transport.Connection>) invocationOnMock.getArguments()[2];
+            listener.onFailure(new ConnectTransportException(node, ""));
+            return null;
+        }).when(transport).openConnection(eq(node), eq(connectionProfile), any(ActionListener.class));
+
+        assertFalse(connectionManager.nodeConnected(node));
+
+        ConnectionManager.ConnectionValidator validator = (c, p, l) -> l.onResponse(null);
+
+        PlainActionFuture<Void> fut = new PlainActionFuture<>();
+        connectionManager.connectToNode(node, connectionProfile, validator, fut);
+        expectThrows(ConnectTransportException.class, () -> fut.actionGet());
+
+        assertFalse(connectionManager.nodeConnected(node));
+        expectThrows(NodeNotConnectedException.class, () -> connectionManager.getConnection(node));
+        assertEquals(0, connectionManager.size());
+        assertEquals(0, nodeConnectedCount.get());
+        assertEquals(0, nodeDisconnectedCount.get());
+    }
+
     private static class TestConnect extends CloseableConnection {
 
         private final DiscoveryNode node;

+ 25 - 25
server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java

@@ -1261,35 +1261,35 @@ public class RemoteClusterConnectionTests extends ESTestCase {
                     // route by seed hostname
                     proxyNode = proxyMapping.get(node.getHostName());
                 }
-            return t.openConnection(proxyNode, profile, ActionListener.delegateFailure(listener,
-                (delegatedListener, connection) -> delegatedListener.onResponse(
-                    new Transport.Connection() {
-                        @Override
-                        public DiscoveryNode getNode() {
-                            return node;
-                        }
+                t.openConnection(proxyNode, profile, ActionListener.delegateFailure(listener,
+                    (delegatedListener, connection) -> delegatedListener.onResponse(
+                        new Transport.Connection() {
+                            @Override
+                            public DiscoveryNode getNode() {
+                                return node;
+                            }
 
-                        @Override
-                        public void sendRequest(long requestId, String action, TransportRequest request,
-                                                TransportRequestOptions options) throws IOException {
-                            connection.sendRequest(requestId, action, request, options);
-                        }
+                            @Override
+                            public void sendRequest(long requestId, String action, TransportRequest request,
+                                                    TransportRequestOptions options) throws IOException {
+                                connection.sendRequest(requestId, action, request, options);
+                            }
 
-                        @Override
-                        public void addCloseListener(ActionListener<Void> listener) {
-                            connection.addCloseListener(listener);
-                        }
+                            @Override
+                            public void addCloseListener(ActionListener<Void> listener) {
+                                connection.addCloseListener(listener);
+                            }
 
-                        @Override
-                        public boolean isClosed() {
-                            return connection.isClosed();
-                        }
+                            @Override
+                            public boolean isClosed() {
+                                return connection.isClosed();
+                            }
 
-                        @Override
-                        public void close() {
-                            connection.close();
-                        }
-                    })));
+                            @Override
+                            public void close() {
+                                connection.close();
+                            }
+                        })));
         });
         return stubbableTransport;
     }

+ 5 - 3
server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java

@@ -20,6 +20,8 @@
 package org.elasticsearch.transport;
 
 import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.network.NetworkService;
@@ -108,7 +110,7 @@ public class TransportServiceHandshakeTests extends ESTestCase {
             emptySet(),
             Version.CURRENT.minimumCompatibilityVersion());
         try (Transport.Connection connection = handleA.transportService.openConnection(discoveryNode, TestProfiles.LIGHT_PROFILE)){
-            DiscoveryNode connectedNode = handleA.transportService.handshake(connection, timeout);
+            DiscoveryNode connectedNode = PlainActionFuture.get(fut -> handleA.transportService.handshake(connection, timeout, fut));
             assertNotNull(connectedNode);
             // the name and version should be updated
             assertEquals(connectedNode.getName(), "TS_B");
@@ -130,7 +132,7 @@ public class TransportServiceHandshakeTests extends ESTestCase {
         IllegalStateException ex = expectThrows(IllegalStateException.class, () -> {
             try (Transport.Connection connection = handleA.transportService.openConnection(discoveryNode,
                 TestProfiles.LIGHT_PROFILE)) {
-                handleA.transportService.handshake(connection, timeout);
+                PlainActionFuture.get(fut -> handleA.transportService.handshake(connection, timeout, ActionListener.map(fut, x -> null)));
             }
         });
         assertThat(ex.getMessage(), containsString("handshake failed, mismatched cluster name [Cluster [b]]"));
@@ -151,7 +153,7 @@ public class TransportServiceHandshakeTests extends ESTestCase {
         IllegalStateException ex = expectThrows(IllegalStateException.class, () -> {
             try (Transport.Connection connection = handleA.transportService.openConnection(discoveryNode,
                 TestProfiles.LIGHT_PROFILE)) {
-                handleA.transportService.handshake(connection, timeout);
+                PlainActionFuture.get(fut -> handleA.transportService.handshake(connection, timeout, ActionListener.map(fut, x -> null)));
             }
         });
         assertThat(ex.getMessage(), containsString("handshake failed, incompatible version"));

+ 23 - 40
test/framework/src/main/java/org/elasticsearch/test/disruption/DisruptableMockTransport.java

@@ -23,7 +23,6 @@ import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.lease.Releasable;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.BoundTransportAddress;
@@ -48,7 +47,6 @@ import java.util.Set;
 import java.util.function.Function;
 
 import static org.elasticsearch.test.ESTestCase.copyWriteable;
-import static org.elasticsearch.transport.TransportService.HANDSHAKE_ACTION_NAME;
 
 public abstract class DisruptableMockTransport extends MockTransport {
     private final DiscoveryNode localNode;
@@ -65,15 +63,6 @@ public abstract class DisruptableMockTransport extends MockTransport {
 
     protected abstract void execute(Runnable runnable);
 
-    protected final void execute(String action, Runnable runnable) {
-        // handshake needs to run inline as the caller blockingly waits on the result
-        if (action.equals(HANDSHAKE_ACTION_NAME)) {
-            runnable.run();
-        } else {
-            execute(runnable);
-        }
-    }
-
     public DiscoveryNode getLocalNode() {
         return localNode;
     }
@@ -86,30 +75,30 @@ public abstract class DisruptableMockTransport extends MockTransport {
     }
 
     @Override
-    public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
+    public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
         final Optional<DisruptableMockTransport> optionalMatchingTransport = getDisruptableMockTransport(node.getAddress());
         if (optionalMatchingTransport.isPresent()) {
             final DisruptableMockTransport matchingTransport = optionalMatchingTransport.get();
             final ConnectionStatus connectionStatus = getConnectionStatus(matchingTransport.getLocalNode());
             if (connectionStatus != ConnectionStatus.CONNECTED) {
-                throw new ConnectTransportException(node, "node [" + node + "] is [" + connectionStatus + "] not [CONNECTED]");
-            }
-
-            listener.onResponse(new CloseableConnection() {
-                @Override
-                public DiscoveryNode getNode() {
-                    return node;
-                }
+                listener.onFailure(
+                    new ConnectTransportException(node, "node [" + node + "] is [" + connectionStatus + "] not [CONNECTED]"));
+            } else {
+                listener.onResponse(new CloseableConnection() {
+                    @Override
+                    public DiscoveryNode getNode() {
+                        return node;
+                    }
 
-                @Override
-                public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options)
-                    throws TransportException {
-                    onSendRequest(requestId, action, request, matchingTransport);
-                }
-            });
-            return () -> {};
+                    @Override
+                    public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options)
+                        throws TransportException {
+                        onSendRequest(requestId, action, request, matchingTransport);
+                    }
+                });
+            }
         } else {
-            throw new ConnectTransportException(node, "node [" + node + "] does not exist");
+            listener.onFailure(new ConnectTransportException(node, "node " + node + " does not exist"));
         }
     }
 
@@ -119,7 +108,7 @@ public abstract class DisruptableMockTransport extends MockTransport {
         assert destinationTransport.getLocalNode().equals(getLocalNode()) == false :
             "non-local message from " + getLocalNode() + " to itself";
 
-        destinationTransport.execute(action, new Runnable() {
+        destinationTransport.execute(new Runnable() {
             @Override
             public void run() {
                 final ConnectionStatus connectionStatus = getConnectionStatus(destinationTransport.getLocalNode());
@@ -169,18 +158,11 @@ public abstract class DisruptableMockTransport extends MockTransport {
     }
 
     protected void onBlackholedDuringSend(long requestId, String action, DisruptableMockTransport destinationTransport) {
-        if (action.equals(HANDSHAKE_ACTION_NAME)) {
-            logger.trace("ignoring blackhole and delivering {}",
-                getRequestDescription(requestId, action, destinationTransport.getLocalNode()));
-            // handshakes always have a timeout, and are sent in a blocking fashion, so we must respond with an exception.
-            destinationTransport.execute(action, getDisconnectException(requestId, action, destinationTransport.getLocalNode()));
-        } else {
-            logger.trace("dropping {}", getRequestDescription(requestId, action, destinationTransport.getLocalNode()));
-        }
+        logger.trace("dropping {}", getRequestDescription(requestId, action, destinationTransport.getLocalNode()));
     }
 
     protected void onDisconnectedDuringSend(long requestId, String action, DisruptableMockTransport destinationTransport) {
-        destinationTransport.execute(action, getDisconnectException(requestId, action, destinationTransport.getLocalNode()));
+        destinationTransport.execute(getDisconnectException(requestId, action, destinationTransport.getLocalNode()));
     }
 
     protected void onConnectedDuringSend(long requestId, String action, TransportRequest request,
@@ -205,7 +187,7 @@ public abstract class DisruptableMockTransport extends MockTransport {
 
             @Override
             public void sendResponse(final TransportResponse response) {
-                execute(action, new Runnable() {
+                execute(new Runnable() {
                     @Override
                     public void run() {
                         final ConnectionStatus connectionStatus = destinationTransport.getConnectionStatus(getLocalNode());
@@ -234,7 +216,8 @@ public abstract class DisruptableMockTransport extends MockTransport {
 
             @Override
             public void sendResponse(Exception exception) {
-                execute(action, new Runnable() {
+
+                execute(new Runnable() {
                     @Override
                     public void run() {
                         final ConnectionStatus connectionStatus = destinationTransport.getConnectionStatus(getLocalNode());

+ 1 - 3
test/framework/src/main/java/org/elasticsearch/test/transport/MockTransport.java

@@ -31,7 +31,6 @@ import org.elasticsearch.common.component.LifecycleListener;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
-import org.elasticsearch.common.lease.Releasable;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.BoundTransportAddress;
@@ -164,9 +163,8 @@ public class MockTransport implements Transport, LifecycleComponent {
     }
 
     @Override
-    public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
+    public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
         listener.onResponse(createConnection(node));
-        return () -> {};
     }
 
     public Connection createConnection(DiscoveryNode node) {

+ 9 - 17
test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java

@@ -29,7 +29,6 @@ import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
-import org.elasticsearch.common.lease.Releasable;
 import org.elasticsearch.common.network.NetworkService;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Setting;
@@ -222,10 +221,8 @@ public final class MockTransportService extends TransportService {
      * is added to fail as well.
      */
     public void addFailToSendNoConnectRule(TransportAddress transportAddress) {
-        transport().addConnectBehavior(transportAddress, (transport, discoveryNode, profile, listener) -> {
-            listener.onFailure(new ConnectTransportException(discoveryNode, "DISCONNECT: simulated"));
-            return () -> {};
-        });
+        transport().addConnectBehavior(transportAddress, (transport, discoveryNode, profile, listener) ->
+            listener.onFailure(new ConnectTransportException(discoveryNode, "DISCONNECT: simulated")));
 
         transport().addSendBehavior(transportAddress, (connection, requestId, action, request, options) -> {
             connection.close();
@@ -278,10 +275,8 @@ public final class MockTransportService extends TransportService {
      * and failing to connect once the rule was added.
      */
     public void addUnresponsiveRule(TransportAddress transportAddress) {
-        transport().addConnectBehavior(transportAddress, (transport, discoveryNode, profile, listener) -> {
-            listener.onFailure(new ConnectTransportException(discoveryNode, "UNRESPONSIVE: simulated"));
-            return () -> {};
-        });
+        transport().addConnectBehavior(transportAddress, (transport, discoveryNode, profile, listener) ->
+            listener.onFailure(new ConnectTransportException(discoveryNode, "UNRESPONSIVE: simulated")));
 
         transport().addSendBehavior(transportAddress, new StubbableTransport.SendRequestBehavior() {
             private Set<Transport.Connection> toClose = ConcurrentHashMap.newKeySet();
@@ -331,11 +326,12 @@ public final class MockTransportService extends TransportService {
         transport().addConnectBehavior(transportAddress, new StubbableTransport.OpenConnectionBehavior() {
             private CountDownLatch stopLatch = new CountDownLatch(1);
             @Override
-            public Releasable openConnection(Transport transport, DiscoveryNode discoveryNode,
+            public void openConnection(Transport transport, DiscoveryNode discoveryNode,
                                              ConnectionProfile profile, ActionListener<Transport.Connection> listener) {
                 TimeValue delay = delaySupplier.get();
                 if (delay.millis() <= 0) {
-                    return original.openConnection(discoveryNode, profile, listener);
+                    original.openConnection(discoveryNode, profile, listener);
+                    return;
                 }
 
                 // TODO: Replace with proper setting
@@ -343,17 +339,13 @@ public final class MockTransportService extends TransportService {
                 try {
                     if (delay.millis() < connectingTimeout.millis()) {
                         stopLatch.await(delay.millis(), TimeUnit.MILLISECONDS);
-                        return original.openConnection(discoveryNode, profile, listener);
+                        original.openConnection(discoveryNode, profile, listener);
                     } else {
                         stopLatch.await(connectingTimeout.millis(), TimeUnit.MILLISECONDS);
                         listener.onFailure(new ConnectTransportException(discoveryNode, "UNRESPONSIVE: simulated"));
-                        return () -> {
-                        };
                     }
                 } catch (InterruptedException e) {
                     listener.onFailure(new ConnectTransportException(discoveryNode, "UNRESPONSIVE: simulated"));
-                    return () -> {
-                    };
                 }
             }
 
@@ -524,7 +516,7 @@ public final class MockTransportService extends TransportService {
     }
 
     @Override
-    public Transport.Connection openConnection(DiscoveryNode node, ConnectionProfile profile) throws IOException {
+    public Transport.Connection openConnection(DiscoveryNode node, ConnectionProfile profile) {
         Transport.Connection connection = super.openConnection(node, profile);
 
         synchronized (openConnections) {

+ 5 - 6
test/framework/src/main/java/org/elasticsearch/test/transport/StubbableConnectionManager.java

@@ -18,8 +18,8 @@
  */
 package org.elasticsearch.test.transport;
 
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.node.DiscoveryNode;
-import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.transport.ConnectTransportException;
@@ -28,7 +28,6 @@ import org.elasticsearch.transport.ConnectionProfile;
 import org.elasticsearch.transport.Transport;
 import org.elasticsearch.transport.TransportConnectionListener;
 
-import java.io.IOException;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 
@@ -80,8 +79,8 @@ public class StubbableConnectionManager extends ConnectionManager {
     }
 
     @Override
-    public Transport.Connection openConnection(DiscoveryNode node, ConnectionProfile connectionProfile) {
-        return delegate.openConnection(node, connectionProfile);
+    public void openConnection(DiscoveryNode node, ConnectionProfile connectionProfile, ActionListener<Transport.Connection> listener) {
+        delegate.openConnection(node, connectionProfile, listener);
     }
 
     @Override
@@ -110,9 +109,9 @@ public class StubbableConnectionManager extends ConnectionManager {
 
     @Override
     public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile,
-                              CheckedBiConsumer<Transport.Connection, ConnectionProfile, IOException> connectionValidator)
+                              ConnectionValidator connectionValidator, ActionListener<Void> listener)
         throws ConnectTransportException {
-        delegate.connectToNode(node, connectionProfile, connectionValidator);
+        delegate.connectToNode(node, connectionProfile, connectionValidator, listener);
     }
 
     @Override

+ 5 - 6
test/framework/src/main/java/org/elasticsearch/test/transport/StubbableTransport.java

@@ -24,7 +24,6 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.component.Lifecycle;
 import org.elasticsearch.common.component.LifecycleListener;
-import org.elasticsearch.common.lease.Releasable;
 import org.elasticsearch.common.transport.BoundTransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.transport.ConnectionProfile;
@@ -128,7 +127,7 @@ public final class StubbableTransport implements Transport {
     }
 
     @Override
-    public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
+    public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
         TransportAddress address = node.getAddress();
         OpenConnectionBehavior behavior = connectBehaviors.getOrDefault(address, defaultConnectBehavior);
 
@@ -137,9 +136,9 @@ public final class StubbableTransport implements Transport {
                 (delegatedListener, connection) -> delegatedListener.onResponse(new WrappedConnection(connection)));
 
         if (behavior == null) {
-            return delegate.openConnection(node, profile, wrappedListener);
+            delegate.openConnection(node, profile, wrappedListener);
         } else {
-            return behavior.openConnection(delegate, node, profile, wrappedListener);
+            behavior.openConnection(delegate, node, profile, wrappedListener);
         }
     }
 
@@ -247,8 +246,8 @@ public final class StubbableTransport implements Transport {
     @FunctionalInterface
     public interface OpenConnectionBehavior {
 
-        Releasable openConnection(Transport transport, DiscoveryNode discoveryNode, ConnectionProfile profile,
-                                  ActionListener<Connection> listener);
+        void openConnection(Transport transport, DiscoveryNode discoveryNode, ConnectionProfile profile,
+                            ActionListener<Connection> listener);
 
         default void clearCallback() {}
     }

+ 8 - 2
test/framework/src/test/java/org/elasticsearch/test/disruption/DisruptableMockTransportTests.java

@@ -20,6 +20,7 @@
 package org.elasticsearch.test.disruption;
 
 import org.elasticsearch.Version;
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.coordination.DeterministicTaskQueue;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.collect.Tuple;
@@ -146,8 +147,13 @@ public class DisruptableMockTransportTests extends ESTestCase {
         service1.start();
         service2.start();
 
-        service1.connectToNode(node2);
-        service2.connectToNode(node1);
+        final PlainActionFuture<Void> fut1 = new PlainActionFuture<>();
+        service1.connectToNode(node2, fut1);
+        final PlainActionFuture<Void> fut2 = new PlainActionFuture<>();
+        service2.connectToNode(node1, fut2);
+        deterministicTaskQueue.runAllTasksInTimeOrder();
+        assertTrue(fut1.isDone());
+        assertTrue(fut2.isDone());
     }
 
     private TransportRequestHandler<TransportRequest.Empty> requestHandlerShouldNotBeCalled() {