Browse Source

Mock connections more accurately in DisruptableMockTransport (#37296)

This commit moves DisruptableMockTransport to use a more accurate representation of connection
management, which allows to use the full connection manager and does not require mocking out
any behavior. With this, we can implement restarting nodes in CoordinatorTests.
Yannick Welsch 6 years ago
parent
commit
f4abf9628a

+ 1 - 1
server/src/main/java/org/elasticsearch/cluster/coordination/CoordinationState.java

@@ -209,7 +209,7 @@ public class CoordinationState {
      * @throws CoordinationStateRejectedException if the arguments were incompatible with the current state of this object.
      */
     public boolean handleJoin(Join join) {
-        assert join.getTargetNode().equals(localNode) : "handling join " + join + " for the wrong node " + localNode;
+        assert join.targetMatches(localNode) : "handling join " + join + " for the wrong node " + localNode;
 
         if (join.getTerm() != getCurrentTerm()) {
             logger.debug("handleJoin: ignored join due to term mismatch (expected: [{}], actual: [{}])",

+ 1 - 1
server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java

@@ -310,7 +310,7 @@ public class Coordinator extends AbstractLifecycleComponent implements Discovery
 
     private static Optional<Join> joinWithDestination(Optional<Join> lastJoin, DiscoveryNode leader, long term) {
         if (lastJoin.isPresent()
-            && lastJoin.get().getTargetNode().getId().equals(leader.getId())
+            && lastJoin.get().targetMatches(leader)
             && lastJoin.get().getTerm() == term) {
             return lastJoin;
         }

+ 4 - 0
server/src/main/java/org/elasticsearch/cluster/coordination/Join.java

@@ -78,6 +78,10 @@ public class Join implements Writeable {
         return targetNode;
     }
 
+    public boolean targetMatches(DiscoveryNode matchingNode) {
+        return targetNode.getId().equals(matchingNode.getId());
+    }
+
     public long getLastAcceptedVersion() {
         return lastAcceptedVersion;
     }

+ 132 - 65
server/src/test/java/org/elasticsearch/cluster/coordination/CoordinatorTests.java

@@ -26,6 +26,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.cluster.ClusterModule;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStateUpdateTask;
@@ -42,6 +43,10 @@ import org.elasticsearch.cluster.node.DiscoveryNode.Role;
 import org.elasticsearch.cluster.service.ClusterApplier;
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.lease.Releasable;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Setting;
@@ -77,7 +82,6 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Function;
-import java.util.function.Predicate;
 import java.util.function.Supplier;
 import java.util.function.UnaryOperator;
 import java.util.stream.Collectors;
@@ -107,7 +111,6 @@ import static org.elasticsearch.discovery.DiscoverySettings.NO_MASTER_BLOCK_SETT
 import static org.elasticsearch.discovery.DiscoverySettings.NO_MASTER_BLOCK_WRITES;
 import static org.elasticsearch.discovery.PeerFinder.DISCOVERY_FIND_PEERS_INTERVAL_SETTING;
 import static org.elasticsearch.node.Node.NODE_NAME_SETTING;
-import static org.elasticsearch.transport.TransportService.HANDSHAKE_ACTION_NAME;
 import static org.elasticsearch.transport.TransportService.NOOP_TRANSPORT_INTERCEPTOR;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
@@ -930,7 +933,7 @@ public class CoordinatorTests extends ESTestCase {
 
         final ClusterNode leader = cluster.getAnyLeader();
         final ClusterNode nonLeader = cluster.getAnyNodeExcept(leader);
-        onNode(nonLeader.getLocalNode(), () -> {
+        nonLeader.onNode(() -> {
             logger.debug("forcing {} to become candidate", nonLeader.getId());
             synchronized (nonLeader.coordinator.mutex) {
                 nonLeader.coordinator.becomeCandidate("forced");
@@ -1161,6 +1164,11 @@ public class CoordinatorTests extends ESTestCase {
             assertThat("may reconnect disconnected nodes, probably unexpected", disconnectedNodes, empty());
             assertThat("may reconnect blackholed nodes, probably unexpected", blackholedNodes, empty());
 
+            final List<Runnable> cleanupActions = new ArrayList<>();
+            cleanupActions.add(disconnectedNodes::clear);
+            cleanupActions.add(blackholedNodes::clear);
+            cleanupActions.add(() -> disruptStorage = false);
+
             final int randomSteps = scaledRandomIntBetween(10, 10000);
             logger.info("--> start of safety phase of at least [{}] steps", randomSteps);
 
@@ -1183,7 +1191,7 @@ public class CoordinatorTests extends ESTestCase {
                     if (rarely()) {
                         final ClusterNode clusterNode = getAnyNodePreferringLeaders();
                         final int newValue = randomInt();
-                        onNode(clusterNode.getLocalNode(), () -> {
+                        clusterNode.onNode(() -> {
                             logger.debug("----> [runRandomly {}] proposing new value [{}] to [{}]",
                                 thisStep, newValue, clusterNode.getId());
                             clusterNode.submitValue(newValue);
@@ -1191,15 +1199,34 @@ public class CoordinatorTests extends ESTestCase {
                     } else if (rarely()) {
                         final ClusterNode clusterNode = getAnyNodePreferringLeaders();
                         final boolean autoShrinkVotingConfiguration = randomBoolean();
-                        onNode(clusterNode.getLocalNode(),
+                        clusterNode.onNode(
                             () -> {
                                 logger.debug("----> [runRandomly {}] setting auto-shrink configuration to {} on {}",
                                     thisStep, autoShrinkVotingConfiguration, clusterNode.getId());
                                 clusterNode.submitSetAutoShrinkVotingConfiguration(autoShrinkVotingConfiguration);
                             }).run();
+                    } else if (rarely()) {
+                        // reboot random node
+                        final ClusterNode clusterNode = getAnyNode();
+                        logger.debug("----> [runRandomly {}] rebooting [{}]", thisStep, clusterNode.getId());
+                        clusterNode.close();
+                        clusterNodes.forEach(
+                            cn -> deterministicTaskQueue.scheduleNow(cn.onNode(
+                                new Runnable() {
+                                    @Override
+                                    public void run() {
+                                        cn.transportService.disconnectFromNode(clusterNode.getLocalNode());
+                                    }
+
+                                    @Override
+                                    public String toString() {
+                                        return "disconnect from " + clusterNode.getLocalNode() + " after shutdown";
+                                    }
+                                })));
+                        clusterNodes.replaceAll(cn -> cn == clusterNode ? cn.restartedNode() : cn);
                     } else if (rarely()) {
                         final ClusterNode clusterNode = getAnyNode();
-                        onNode(clusterNode.getLocalNode(), () -> {
+                        clusterNode.onNode(() -> {
                             logger.debug("----> [runRandomly {}] forcing {} to become candidate", thisStep, clusterNode.getId());
                             synchronized (clusterNode.coordinator.mutex) {
                                 clusterNode.coordinator.becomeCandidate("runRandomly");
@@ -1227,7 +1254,7 @@ public class CoordinatorTests extends ESTestCase {
                         }
                     } else if (rarely()) {
                         final ClusterNode clusterNode = getAnyNode();
-                        onNode(clusterNode.getLocalNode(),
+                        clusterNode.onNode(
                             () -> {
                                 logger.debug("----> [runRandomly {}] applying initial configuration {} to {}",
                                     thisStep, initialConfiguration, clusterNode.getId());
@@ -1252,9 +1279,9 @@ public class CoordinatorTests extends ESTestCase {
                 assertConsistentStates();
             }
 
-            disconnectedNodes.clear();
-            blackholedNodes.clear();
-            disruptStorage = false;
+            logger.debug("running {} cleanup actions", cleanupActions.size());
+            cleanupActions.forEach(Runnable::run);
+            logger.debug("finished running cleanup actions");
         }
 
         private void assertConsistentStates() {
@@ -1406,18 +1433,28 @@ public class CoordinatorTests extends ESTestCase {
             return randomFrom(allLeaders);
         }
 
+        private final ConnectionStatus preferredUnknownNodeConnectionStatus =
+            randomFrom(ConnectionStatus.DISCONNECTED, ConnectionStatus.BLACK_HOLE);
+
         private ConnectionStatus getConnectionStatus(DiscoveryNode sender, DiscoveryNode destination) {
             ConnectionStatus connectionStatus;
             if (blackholedNodes.contains(sender.getId()) || blackholedNodes.contains(destination.getId())) {
                 connectionStatus = ConnectionStatus.BLACK_HOLE;
             } else if (disconnectedNodes.contains(sender.getId()) || disconnectedNodes.contains(destination.getId())) {
                 connectionStatus = ConnectionStatus.DISCONNECTED;
-            } else {
+            } else if (nodeExists(sender) && nodeExists(destination)) {
                 connectionStatus = ConnectionStatus.CONNECTED;
+            } else {
+                connectionStatus = usually() ? preferredUnknownNodeConnectionStatus :
+                    randomFrom(ConnectionStatus.DISCONNECTED, ConnectionStatus.BLACK_HOLE);
             }
             return connectionStatus;
         }
 
+        boolean nodeExists(DiscoveryNode node) {
+            return clusterNodes.stream().anyMatch(cn -> cn.getLocalNode().equals(node));
+        }
+
         ClusterNode getAnyMasterEligibleNode() {
             return randomFrom(clusterNodes.stream().filter(n -> n.getLocalNode().isMasterNode()).collect(Collectors.toList()));
         }
@@ -1486,7 +1523,7 @@ public class CoordinatorTests extends ESTestCase {
 
             private final int nodeIndex;
             private Coordinator coordinator;
-            private DiscoveryNode localNode;
+            private final DiscoveryNode localNode;
             private final PersistedState persistedState;
             private FakeClusterApplier clusterApplier;
             private AckedFakeThreadPoolMasterService masterService;
@@ -1496,63 +1533,34 @@ public class CoordinatorTests extends ESTestCase {
             private ClusterStateApplyResponse clusterStateApplyResponse = ClusterStateApplyResponse.SUCCEED;
 
             ClusterNode(int nodeIndex, boolean masterEligible) {
-                this.nodeIndex = nodeIndex;
-                localNode = createDiscoveryNode(masterEligible);
-                persistedState = new MockPersistedState(0L,
-                    clusterState(0L, 0L, localNode, VotingConfiguration.EMPTY_CONFIG, VotingConfiguration.EMPTY_CONFIG, 0L));
-                onNode(localNode, this::setUp).run();
+                this(nodeIndex, createDiscoveryNode(nodeIndex, masterEligible),
+                    localNode -> new MockPersistedState(0L,
+                        clusterState(0L, 0L, localNode, VotingConfiguration.EMPTY_CONFIG, VotingConfiguration.EMPTY_CONFIG, 0L)));
             }
 
-            private DiscoveryNode createDiscoveryNode(boolean masterEligible) {
-                final TransportAddress address = buildNewFakeTransportAddress();
-                return new DiscoveryNode("", "node" + nodeIndex,
-                    UUIDs.randomBase64UUID(random()), // generated deterministically for repeatable tests
-                    address.address().getHostString(), address.getAddress(), address, Collections.emptyMap(),
-                    masterEligible ? EnumSet.allOf(Role.class) : emptySet(), Version.CURRENT);
+            ClusterNode(int nodeIndex, DiscoveryNode localNode, Function<DiscoveryNode, PersistedState> persistedStateSupplier) {
+                this.nodeIndex = nodeIndex;
+                this.localNode = localNode;
+                persistedState = persistedStateSupplier.apply(localNode);
+                onNodeLog(localNode, this::setUp).run();
             }
 
             private void setUp() {
-                mockTransport = new DisruptableMockTransport(logger) {
-                    @Override
-                    protected DiscoveryNode getLocalNode() {
-                        return localNode;
-                    }
-
+                mockTransport = new DisruptableMockTransport(localNode, logger) {
                     @Override
-                    protected ConnectionStatus getConnectionStatus(DiscoveryNode sender, DiscoveryNode destination) {
-                        return Cluster.this.getConnectionStatus(sender, destination);
+                    protected void execute(Runnable runnable) {
+                        deterministicTaskQueue.scheduleNow(onNode(runnable));
                     }
 
                     @Override
-                    protected Optional<DisruptableMockTransport> getDisruptedCapturingTransport(DiscoveryNode node, String action) {
-                        final Predicate<ClusterNode> matchesDestination;
-                        if (action.equals(HANDSHAKE_ACTION_NAME)) {
-                            matchesDestination = n -> n.getLocalNode().getAddress().equals(node.getAddress());
-                        } else {
-                            matchesDestination = n -> n.getLocalNode().equals(node);
-                        }
-                        return clusterNodes.stream().filter(matchesDestination).findAny().map(cn -> cn.mockTransport);
+                    protected ConnectionStatus getConnectionStatus(DiscoveryNode destination) {
+                        return Cluster.this.getConnectionStatus(getLocalNode(), destination);
                     }
 
                     @Override
-                    protected void handle(DiscoveryNode sender, DiscoveryNode destination, String action, Runnable doDelivery) {
-                        // handshake needs to run inline as the caller blockingly waits on the result
-                        if (action.equals(HANDSHAKE_ACTION_NAME)) {
-                            onNode(destination, doDelivery).run();
-                        } else {
-                            deterministicTaskQueue.scheduleNow(onNode(destination, doDelivery));
-                        }
-                    }
-
-                    @Override
-                    protected void onBlackholedDuringSend(long requestId, String action, DiscoveryNode destination) {
-                        if (action.equals(HANDSHAKE_ACTION_NAME)) {
-                            logger.trace("ignoring blackhole and delivering {}", getRequestDescription(requestId, action, destination));
-                            // handshakes always have a timeout, and are sent in a blocking fashion, so we must respond with an exception.
-                            sendFromTo(destination, getLocalNode(), action, getDisconnectException(requestId, action, destination));
-                        } else {
-                            super.onBlackholedDuringSend(requestId, action, destination);
-                        }
+                    protected Optional<DisruptableMockTransport> getDisruptableMockTransport(TransportAddress address) {
+                        return clusterNodes.stream().map(cn -> cn.mockTransport)
+                            .filter(transport -> transport.getLocalNode().getAddress().equals(address)).findAny();
                     }
                 };
 
@@ -1563,9 +1571,9 @@ public class CoordinatorTests extends ESTestCase {
                 final ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
                 clusterApplier = new FakeClusterApplier(settings, clusterSettings);
                 masterService = new AckedFakeThreadPoolMasterService("test_node", "test",
-                    runnable -> deterministicTaskQueue.scheduleNow(onNode(localNode, runnable)));
+                    runnable -> deterministicTaskQueue.scheduleNow(onNode(runnable)));
                 transportService = mockTransport.createTransportService(
-                    settings, deterministicTaskQueue.getThreadPool(runnable -> onNode(localNode, runnable)), NOOP_TRANSPORT_INTERCEPTOR,
+                    settings, deterministicTaskQueue.getThreadPool(this::onNode), NOOP_TRANSPORT_INTERCEPTOR,
                     a -> localNode, null, emptySet());
                 final Collection<BiConsumer<DiscoveryNode, ClusterState>> onJoinValidators =
                     Collections.singletonList((dn, cs) -> extraJoinValidators.forEach(validator -> validator.accept(dn, cs)));
@@ -1574,6 +1582,7 @@ public class CoordinatorTests extends ESTestCase {
                     Cluster.this::provideUnicastHosts, clusterApplier, onJoinValidators, Randomness.get());
                 masterService.setClusterStatePublisher(coordinator);
 
+                logger.trace("starting up [{}]", localNode);
                 transportService.start();
                 transportService.acceptIncomingRequests();
                 masterService.start();
@@ -1581,6 +1590,37 @@ public class CoordinatorTests extends ESTestCase {
                 coordinator.startInitialJoin();
             }
 
+            void close() {
+                logger.trace("taking down [{}]", localNode);
+                //transportService.stop(); // does blocking stuff :/
+                masterService.stop();
+                coordinator.stop();
+                //transportService.close(); // does blocking stuff :/
+                masterService.close();
+                coordinator.close();
+            }
+
+            ClusterNode restartedNode() {
+                final TransportAddress address = randomBoolean() ? buildNewFakeTransportAddress() : localNode.getAddress();
+                final DiscoveryNode newLocalNode = new DiscoveryNode(localNode.getName(), localNode.getId(),
+                    UUIDs.randomBase64UUID(random()), // generated deterministically for repeatable tests
+                    address.address().getHostString(), address.getAddress(), address, Collections.emptyMap(),
+                    localNode.isMasterNode() ? EnumSet.allOf(Role.class) : emptySet(), Version.CURRENT);
+                final PersistedState newPersistedState;
+                try {
+                    BytesStreamOutput outStream = new BytesStreamOutput();
+                    outStream.setVersion(Version.CURRENT);
+                    persistedState.getLastAcceptedState().writeTo(outStream);
+                    StreamInput inStream = new NamedWriteableAwareStreamInput(outStream.bytes().streamInput(),
+                        new NamedWriteableRegistry(ClusterModule.getNamedWriteables()));
+                    newPersistedState = new MockPersistedState(persistedState.getCurrentTerm(),
+                        ClusterState.readFrom(inStream, newLocalNode)); // adapts it to new localNode instance
+                } catch (IOException e) {
+                    throw new UncheckedIOException(e);
+                }
+                return new ClusterNode(nodeIndex, newLocalNode, node -> newPersistedState);
+            }
+
             private PersistedState getPersistedState() {
                 return persistedState;
             }
@@ -1615,6 +1655,25 @@ public class CoordinatorTests extends ESTestCase {
                 return clusterStateApplyResponse;
             }
 
+            Runnable onNode(Runnable runnable) {
+                final Runnable wrapped = onNodeLog(localNode, runnable);
+                return new Runnable() {
+                    @Override
+                    public void run() {
+                        if (clusterNodes.contains(ClusterNode.this) == false) {
+                            logger.trace("ignoring runnable {} from node {} as node has been removed from cluster", runnable, localNode);
+                            return;
+                        }
+                        wrapped.run();
+                    }
+
+                    @Override
+                    public String toString() {
+                        return wrapped.toString();
+                    }
+                };
+            }
+
             void submitSetAutoShrinkVotingConfiguration(final boolean autoShrinkVotingConfiguration) {
                 submitUpdateTask("set master nodes failure tolerance [" + autoShrinkVotingConfiguration + "]", cs ->
                     ClusterState.builder(cs).metaData(
@@ -1633,7 +1692,7 @@ public class CoordinatorTests extends ESTestCase {
 
             AckCollector submitUpdateTask(String source, UnaryOperator<ClusterState> clusterStateUpdate) {
                 final AckCollector ackCollector = new AckCollector();
-                onNode(localNode, () -> {
+                onNode(() -> {
                     logger.trace("[{}] submitUpdateTask: enqueueing [{}]", localNode.getId(), source);
                     final long submittedTerm = coordinator.getCurrentTerm();
                     masterService.submitStateUpdateTask(source,
@@ -1698,7 +1757,7 @@ public class CoordinatorTests extends ESTestCase {
             }
 
             void applyInitialConfiguration() {
-                onNode(localNode, () -> {
+                onNode(() -> {
                     try {
                         coordinator.setInitialConfiguration(initialConfiguration);
                         logger.info("successfully set initial configuration to {}", initialConfiguration);
@@ -1734,7 +1793,7 @@ public class CoordinatorTests extends ESTestCase {
                 public void onNewClusterState(String source, Supplier<ClusterState> clusterStateSupplier, ClusterApplyListener listener) {
                     switch (clusterStateApplyResponse) {
                         case SUCCEED:
-                            deterministicTaskQueue.scheduleNow(onNode(localNode, new Runnable() {
+                            deterministicTaskQueue.scheduleNow(onNode(new Runnable() {
                                 @Override
                                 public void run() {
                                     final ClusterState oldClusterState = clusterApplier.lastAppliedClusterState;
@@ -1754,7 +1813,7 @@ public class CoordinatorTests extends ESTestCase {
                             }));
                             break;
                         case FAIL:
-                            deterministicTaskQueue.scheduleNow(onNode(localNode, new Runnable() {
+                            deterministicTaskQueue.scheduleNow(onNode(new Runnable() {
                                 @Override
                                 public void run() {
                                     listener.onFailure(source, new ElasticsearchException("cluster state application failed"));
@@ -1768,7 +1827,7 @@ public class CoordinatorTests extends ESTestCase {
                             break;
                         case HANG:
                             if (randomBoolean()) {
-                                deterministicTaskQueue.scheduleNow(onNode(localNode, new Runnable() {
+                                deterministicTaskQueue.scheduleNow(onNode(new Runnable() {
                                     @Override
                                     public void run() {
                                         final ClusterState oldClusterState = clusterApplier.lastAppliedClusterState;
@@ -1796,7 +1855,7 @@ public class CoordinatorTests extends ESTestCase {
         }
     }
 
-    public static Runnable onNode(DiscoveryNode node, Runnable runnable) {
+    public static Runnable onNodeLog(DiscoveryNode node, Runnable runnable) {
         final String nodeId = "{" + node.getId() + "}{" + node.getEphemeralId() + "}";
         return new Runnable() {
             @Override
@@ -1880,6 +1939,14 @@ public class CoordinatorTests extends ESTestCase {
         }
     }
 
+    private static DiscoveryNode createDiscoveryNode(int nodeIndex, boolean masterEligible) {
+        final TransportAddress address = buildNewFakeTransportAddress();
+        return new DiscoveryNode("", "node" + nodeIndex,
+            UUIDs.randomBase64UUID(random()), // generated deterministically for repeatable tests
+            address.address().getHostString(), address.getAddress(), address, Collections.emptyMap(),
+            masterEligible ? EnumSet.allOf(Role.class) : emptySet(), Version.CURRENT);
+    }
+
     /**
      * How to behave with a new cluster state
      */

+ 20 - 27
server/src/test/java/org/elasticsearch/snapshots/SnapshotsServiceTests.java

@@ -68,6 +68,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.IndexScopedSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.PageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
@@ -114,7 +115,6 @@ import java.util.Map;
 import java.util.Optional;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Predicate;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
@@ -122,7 +122,6 @@ import static java.util.Collections.emptyMap;
 import static java.util.Collections.emptySet;
 import static org.elasticsearch.env.Environment.PATH_HOME_SETTING;
 import static org.elasticsearch.node.Node.NODE_NAME_SETTING;
-import static org.elasticsearch.transport.TransportService.HANDSHAKE_ACTION_NAME;
 import static org.elasticsearch.transport.TransportService.NOOP_TRANSPORT_INTERCEPTOR;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.hasSize;
@@ -388,41 +387,26 @@ public class SnapshotsServiceTests extends ESTestCase {
                         return new MockSinglePrioritizingExecutor(node.getName(), deterministicTaskQueue);
                     }
                 });
-            mockTransport = new DisruptableMockTransport(logger) {
+            mockTransport = new DisruptableMockTransport(node, logger) {
                 @Override
-                protected DiscoveryNode getLocalNode() {
-                    return node;
-                }
-
-                @Override
-                protected ConnectionStatus getConnectionStatus(DiscoveryNode sender, DiscoveryNode destination) {
+                protected ConnectionStatus getConnectionStatus(DiscoveryNode destination) {
                     return ConnectionStatus.CONNECTED;
                 }
 
                 @Override
-                protected Optional<DisruptableMockTransport> getDisruptedCapturingTransport(DiscoveryNode node, String action) {
-                    final Predicate<TestClusterNode> matchesDestination;
-                    if (action.equals(HANDSHAKE_ACTION_NAME)) {
-                        matchesDestination = n -> n.transportService.getLocalNode().getAddress().equals(node.getAddress());
-                    } else {
-                        matchesDestination = n -> n.transportService.getLocalNode().equals(node);
-                    }
-                    return testClusterNodes.nodes.values().stream().filter(matchesDestination).findAny().map(cn -> cn.mockTransport);
+                protected Optional<DisruptableMockTransport> getDisruptableMockTransport(TransportAddress address) {
+                    return testClusterNodes.nodes.values().stream().map(cn -> cn.mockTransport)
+                        .filter(transport -> transport.getLocalNode().getAddress().equals(address))
+                        .findAny();
                 }
 
                 @Override
-                protected void handle(DiscoveryNode sender, DiscoveryNode destination, String action, Runnable doDelivery) {
-                    // handshake needs to run inline as the caller blockingly waits on the result
-                    final Runnable runnable = CoordinatorTests.onNode(destination, doDelivery);
-                    if (action.equals(HANDSHAKE_ACTION_NAME)) {
-                        runnable.run();
-                    } else {
-                        deterministicTaskQueue.scheduleNow(runnable);
-                    }
+                protected void execute(Runnable runnable) {
+                    deterministicTaskQueue.scheduleNow(CoordinatorTests.onNodeLog(getLocalNode(), runnable));
                 }
             };
             transportService = mockTransport.createTransportService(
-                settings, deterministicTaskQueue.getThreadPool(runnable -> CoordinatorTests.onNode(node, runnable)),
+                settings, deterministicTaskQueue.getThreadPool(runnable -> CoordinatorTests.onNodeLog(node, runnable)),
                 NOOP_TRANSPORT_INTERCEPTOR,
                 a -> node, null, emptySet()
             );
@@ -544,7 +528,16 @@ public class SnapshotsServiceTests extends ESTestCase {
             coordinator.start();
             masterService.start();
             clusterService.getClusterApplierService().setNodeConnectionsService(
-                new NodeConnectionsService(clusterService.getSettings(), threadPool, transportService));
+                new NodeConnectionsService(clusterService.getSettings(), threadPool, transportService) {
+                    @Override
+                    public void connectToNodes(DiscoveryNodes discoveryNodes) {
+                        // override this method as it does blocking calls
+                        for (final DiscoveryNode node : discoveryNodes) {
+                            transportService.connectToNode(node);
+                        }
+                        super.connectToNodes(discoveryNodes);
+                    }
+                });
             clusterService.getClusterApplierService().start();
             indicesService.start();
             indicesClusterStateService.start();

+ 91 - 41
test/framework/src/main/java/org/elasticsearch/test/disruption/DisruptableMockTransport.java

@@ -20,80 +20,123 @@ package org.elasticsearch.test.disruption;
 
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.ClusterModule;
 import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.lease.Releasable;
+import org.elasticsearch.common.settings.ClusterSettings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.transport.BoundTransportAddress;
+import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.test.transport.MockTransport;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.CloseableConnection;
 import org.elasticsearch.transport.ConnectTransportException;
+import org.elasticsearch.transport.ConnectionProfile;
 import org.elasticsearch.transport.RequestHandlerRegistry;
 import org.elasticsearch.transport.TransportChannel;
+import org.elasticsearch.transport.TransportException;
+import org.elasticsearch.transport.TransportInterceptor;
 import org.elasticsearch.transport.TransportRequest;
+import org.elasticsearch.transport.TransportRequestOptions;
 import org.elasticsearch.transport.TransportResponse;
+import org.elasticsearch.transport.TransportService;
 
 import java.io.IOException;
 import java.util.Optional;
+import java.util.Set;
+import java.util.function.Function;
 
 import static org.elasticsearch.test.ESTestCase.copyWriteable;
+import static org.elasticsearch.transport.TransportService.HANDSHAKE_ACTION_NAME;
 
 public abstract class DisruptableMockTransport extends MockTransport {
+    private final DiscoveryNode localNode;
     private final Logger logger;
 
-    public DisruptableMockTransport(Logger logger) {
+    public DisruptableMockTransport(DiscoveryNode localNode, Logger logger) {
+        this.localNode = localNode;
         this.logger = logger;
     }
 
-    protected abstract DiscoveryNode getLocalNode();
+    protected abstract ConnectionStatus getConnectionStatus(DiscoveryNode destination);
 
-    protected abstract ConnectionStatus getConnectionStatus(DiscoveryNode sender, DiscoveryNode destination);
+    protected abstract Optional<DisruptableMockTransport> getDisruptableMockTransport(TransportAddress address);
 
-    protected abstract Optional<DisruptableMockTransport> getDisruptedCapturingTransport(DiscoveryNode node, String action);
+    protected abstract void execute(Runnable runnable);
 
-    protected abstract void handle(DiscoveryNode sender, DiscoveryNode destination, String action, Runnable doDelivery);
+    protected final void execute(String action, Runnable runnable) {
+        // handshake needs to run inline as the caller blockingly waits on the result
+        if (action.equals(HANDSHAKE_ACTION_NAME)) {
+            runnable.run();
+        } else {
 
-    protected final void sendFromTo(DiscoveryNode sender, DiscoveryNode destination, String action, Runnable doDelivery) {
-        handle(sender, destination, action, new Runnable() {
-            @Override
-            public void run() {
-                if (getDisruptedCapturingTransport(destination, action).isPresent()) {
-                    doDelivery.run();
-                } else {
-                    logger.trace("unknown destination in {}", this);
-                }
-            }
+            execute(runnable);
+        }
+    }
 
-            @Override
-            public String toString() {
-                return doDelivery.toString();
-            }
-        });
+    public DiscoveryNode getLocalNode() {
+        return localNode;
+    }
+
+    @Override
+    public TransportService createTransportService(Settings settings, ThreadPool threadPool, TransportInterceptor interceptor,
+                                                   Function<BoundTransportAddress, DiscoveryNode> localNodeFactory,
+                                                   @Nullable ClusterSettings clusterSettings, Set<String> taskHeaders) {
+        return new TransportService(settings, this, threadPool, interceptor, localNodeFactory, clusterSettings, taskHeaders);
     }
 
     @Override
-    protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode destination) {
+    public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
+        final Optional<DisruptableMockTransport> matchingTransport = getDisruptableMockTransport(node.getAddress());
+        if (matchingTransport.isPresent()) {
+            listener.onResponse(new CloseableConnection() {
+                @Override
+                public DiscoveryNode getNode() {
+                    return node;
+                }
 
-        assert destination.equals(getLocalNode()) == false : "non-local message from " + getLocalNode() + " to itself";
+                @Override
+                public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options)
+                    throws TransportException {
+                    onSendRequest(requestId, action, request, matchingTransport.get());
+                }
+            });
+            return () -> {};
+        } else {
+            throw new ConnectTransportException(node, "node " + node + " does not exist");
+        }
+    }
 
-        sendFromTo(getLocalNode(), destination, action, new Runnable() {
+    protected void onSendRequest(long requestId, String action, TransportRequest request,
+                                 DisruptableMockTransport destinationTransport) {
+
+        assert destinationTransport.getLocalNode().equals(getLocalNode()) == false :
+            "non-local message from " + getLocalNode() + " to itself";
+
+        execute(action, new Runnable() {
             @Override
             public void run() {
-                switch (getConnectionStatus(getLocalNode(), destination)) {
+                switch (getConnectionStatus(destinationTransport.getLocalNode())) {
                     case BLACK_HOLE:
-                        onBlackholedDuringSend(requestId, action, destination);
+                        onBlackholedDuringSend(requestId, action, destinationTransport);
                         break;
 
                     case DISCONNECTED:
-                        onDisconnectedDuringSend(requestId, action, destination);
+                        onDisconnectedDuringSend(requestId, action, destinationTransport);
                         break;
 
                     case CONNECTED:
-                        onConnectedDuringSend(requestId, action, request, destination);
+                        onConnectedDuringSend(requestId, action, request, destinationTransport);
                         break;
                 }
             }
 
             @Override
             public String toString() {
-                return getRequestDescription(requestId, action, destination);
+                return getRequestDescription(requestId, action, destinationTransport.getLocalNode());
             }
         });
     }
@@ -117,20 +160,27 @@ public abstract class DisruptableMockTransport extends MockTransport {
             requestId, action, getLocalNode(), destination).getFormattedMessage();
     }
 
-    protected void onBlackholedDuringSend(long requestId, String action, DiscoveryNode destination) {
-        logger.trace("dropping {}", getRequestDescription(requestId, action, destination));
+    protected void onBlackholedDuringSend(long requestId, String action, DisruptableMockTransport destinationTransport) {
+        if (action.equals(HANDSHAKE_ACTION_NAME)) {
+            logger.trace("ignoring blackhole and delivering {}",
+                getRequestDescription(requestId, action, destinationTransport.getLocalNode()));
+            // handshakes always have a timeout, and are sent in a blocking fashion, so we must respond with an exception.
+            destinationTransport.execute(action, getDisconnectException(requestId, action, destinationTransport.getLocalNode()));
+        } else {
+            logger.trace("dropping {}", getRequestDescription(requestId, action, destinationTransport.getLocalNode()));
+        }
     }
 
-    protected void onDisconnectedDuringSend(long requestId, String action, DiscoveryNode destination) {
-        sendFromTo(destination, getLocalNode(), action, getDisconnectException(requestId, action, destination));
+    protected void onDisconnectedDuringSend(long requestId, String action, DisruptableMockTransport destinationTransport) {
+        destinationTransport.execute(action, getDisconnectException(requestId, action, destinationTransport.getLocalNode()));
     }
 
-    protected void onConnectedDuringSend(long requestId, String action, TransportRequest request, DiscoveryNode destination) {
-        Optional<DisruptableMockTransport> destinationTransport = getDisruptedCapturingTransport(destination, action);
-        assert destinationTransport.isPresent();
-
+    protected void onConnectedDuringSend(long requestId, String action, TransportRequest request,
+                                         DisruptableMockTransport destinationTransport) {
         final RequestHandlerRegistry<TransportRequest> requestHandler =
-            destinationTransport.get().getRequestHandler(action);
+            destinationTransport.getRequestHandler(action);
+
+        final DiscoveryNode destination = destinationTransport.getLocalNode();
 
         final String requestDescription = getRequestDescription(requestId, action, destination);
 
@@ -147,10 +197,10 @@ public abstract class DisruptableMockTransport extends MockTransport {
 
             @Override
             public void sendResponse(final TransportResponse response) {
-                sendFromTo(destination, getLocalNode(), action, new Runnable() {
+                execute(action, new Runnable() {
                     @Override
                     public void run() {
-                        if (getConnectionStatus(destination, getLocalNode()) != ConnectionStatus.CONNECTED) {
+                        if (destinationTransport.getConnectionStatus(getLocalNode()) != ConnectionStatus.CONNECTED) {
                             logger.trace("dropping response to {}: channel is not CONNECTED",
                                 requestDescription);
                         } else {
@@ -167,10 +217,10 @@ public abstract class DisruptableMockTransport extends MockTransport {
 
             @Override
             public void sendResponse(Exception exception) {
-                sendFromTo(destination, getLocalNode(), action, new Runnable() {
+                execute(action, new Runnable() {
                     @Override
                     public void run() {
-                        if (getConnectionStatus(destination, getLocalNode()) != ConnectionStatus.CONNECTED) {
+                        if (destinationTransport.getConnectionStatus(getLocalNode()) != ConnectionStatus.CONNECTED) {
                             logger.trace("dropping response to {}: channel is not CONNECTED",
                                 requestDescription);
                         } else {

+ 18 - 37
test/framework/src/test/java/org/elasticsearch/test/disruption/DisruptableMockTransportTests.java

@@ -25,6 +25,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.node.Node;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.disruption.DisruptableMockTransport.ConnectionStatus;
@@ -85,9 +86,6 @@ public class DisruptableMockTransportTests extends ESTestCase {
     public void initTransports() {
         node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT);
         node2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), Version.CURRENT);
-        List<DiscoveryNode> discoNodes = new ArrayList<>();
-        discoNodes.add(node1);
-        discoNodes.add(node2);
 
         disconnectedLinks = new HashSet<>();
         blackholedLinks = new HashSet<>();
@@ -97,57 +95,37 @@ public class DisruptableMockTransportTests extends ESTestCase {
         deterministicTaskQueue = new DeterministicTaskQueue(
             Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "dummy").build(), random());
 
-        transport1 = new DisruptableMockTransport(logger) {
+        transport1 = new DisruptableMockTransport(node1, logger) {
             @Override
-            protected DiscoveryNode getLocalNode() {
-                return node1;
+            protected ConnectionStatus getConnectionStatus(DiscoveryNode destination) {
+                return DisruptableMockTransportTests.this.getConnectionStatus(getLocalNode(), destination);
             }
 
             @Override
-            protected ConnectionStatus getConnectionStatus(DiscoveryNode sender, DiscoveryNode destination) {
-                return DisruptableMockTransportTests.this.getConnectionStatus(sender, destination);
+            protected Optional<DisruptableMockTransport> getDisruptableMockTransport(TransportAddress address) {
+                return transports.stream().filter(t -> t.getLocalNode().getAddress().equals(address)).findAny();
             }
 
             @Override
-            protected Optional<DisruptableMockTransport> getDisruptedCapturingTransport(DiscoveryNode destination, String action) {
-                int index = discoNodes.indexOf(destination);
-                if (index == -1) {
-                    return Optional.empty();
-                } else {
-                    return Optional.of(transports.get(index));
-                }
-            }
-
-            @Override
-            protected void handle(DiscoveryNode sender, DiscoveryNode destination, String action, Runnable doDelivery) {
-                deterministicTaskQueue.scheduleNow(doDelivery);
+            protected void execute(Runnable runnable) {
+                deterministicTaskQueue.scheduleNow(runnable);
             }
         };
 
-        transport2 = new DisruptableMockTransport(logger) {
+        transport2 = new DisruptableMockTransport(node2, logger) {
             @Override
-            protected DiscoveryNode getLocalNode() {
-                return node2;
+            protected ConnectionStatus getConnectionStatus(DiscoveryNode destination) {
+                return DisruptableMockTransportTests.this.getConnectionStatus(getLocalNode(), destination);
             }
 
             @Override
-            protected ConnectionStatus getConnectionStatus(DiscoveryNode sender, DiscoveryNode destination) {
-                return DisruptableMockTransportTests.this.getConnectionStatus(sender, destination);
+            protected Optional<DisruptableMockTransport> getDisruptableMockTransport(TransportAddress address) {
+                return transports.stream().filter(t -> t.getLocalNode().getAddress().equals(address)).findAny();
             }
 
             @Override
-            protected Optional<DisruptableMockTransport> getDisruptedCapturingTransport(DiscoveryNode destination, String action) {
-                int index = discoNodes.indexOf(destination);
-                if (index == -1) {
-                    return Optional.empty();
-                } else {
-                    return Optional.of(transports.get(index));
-                }
-            }
-
-            @Override
-            protected void handle(DiscoveryNode sender, DiscoveryNode destination, String action, Runnable doDelivery) {
-                deterministicTaskQueue.scheduleNow(doDelivery);
+            protected void execute(Runnable runnable) {
+                deterministicTaskQueue.scheduleNow(runnable);
             }
         };
 
@@ -161,6 +139,9 @@ public class DisruptableMockTransportTests extends ESTestCase {
 
         service1.start();
         service2.start();
+
+        service1.connectToNode(node2);
+        service2.connectToNode(node1);
     }