浏览代码

Ensure pending transport handlers are invoked for all channel failures (#25150)

Today, if a channel gets closed due to a disconnect, we notify the response
handler that the connection is closed and the node is disconnected. Unfortunately,
this is not a complete solution since it only works for published connections.
Connections that are unpublished, i.e. for discovery, can hang indefinitely since we
never invoke their handlers when we get a failure while a user is waiting for
the response. This change adds connection tracking to TcpTransport, which ensures
that we notify the corresponding connection if there is a failure on a channel.
Simon Willnauer 8 年之前
父节点
当前提交
186c16ea41

+ 87 - 41
core/src/main/java/org/elasticsearch/transport/TcpTransport.java

@@ -66,6 +66,7 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.util.concurrent.KeyedLock;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.common.util.iterable.Iterables;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
 import org.elasticsearch.monitor.jvm.JvmInfo;
 import org.elasticsearch.rest.RestStatus;
@@ -85,7 +86,6 @@ import java.util.Collections;
 import java.util.EnumMap;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -112,6 +112,7 @@ import static org.elasticsearch.common.settings.Setting.timeSetting;
 import static org.elasticsearch.common.transport.NetworkExceptionHelper.isCloseConnectionException;
 import static org.elasticsearch.common.transport.NetworkExceptionHelper.isConnectException;
 import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
+import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentSet;
 
 public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent implements Transport {
 
@@ -159,6 +160,8 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
     protected volatile TransportServiceAdapter transportServiceAdapter;
     // node id to actual channel
     protected final ConcurrentMap<DiscoveryNode, NodeChannels> connectedNodes = newConcurrentMap();
+    private final Set<NodeChannels> openConnections = newConcurrentSet();
+
     protected final Map<String, List<Channel>> serverChannels = newConcurrentMap();
     protected final ConcurrentMap<String, BoundTransportAddress> profileBoundAddresses = newConcurrentMap();
 
@@ -357,9 +360,8 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
         private final DiscoveryNode node;
         private final AtomicBoolean closed = new AtomicBoolean(false);
         private final Version version;
-        private final Consumer<Connection> onClose;
 
-        public NodeChannels(DiscoveryNode node, Channel[] channels, ConnectionProfile connectionProfile, Consumer<Connection> onClose) {
+        public NodeChannels(DiscoveryNode node, Channel[] channels, ConnectionProfile connectionProfile) {
             this.node = node;
             this.channels = channels;
             assert channels.length == connectionProfile.getNumConnections() : "expected channels size to be == "
@@ -370,7 +372,6 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
                     typeMapping.put(type, handle);
             }
             version = node.getVersion();
-            this.onClose = onClose;
         }
 
         NodeChannels(NodeChannels channels, Version handshakeVersion) {
@@ -378,7 +379,6 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
             this.channels = channels.channels;
             this.typeMapping = channels.typeMapping;
             this.version = handshakeVersion;
-            this.onClose = channels.onClose;
         }
 
         @Override
@@ -413,7 +413,7 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
                 try {
                     closeChannels(Arrays.stream(channels).filter(Objects::nonNull).collect(Collectors.toList()));
                 } finally {
-                    onClose.accept(this);
+                    onNodeChannelsClosed(this);
                 }
             }
         }
@@ -455,27 +455,28 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
                 if (nodeChannels != null) {
                     return;
                 }
+                boolean success = false;
                 try {
-                    try {
-                        nodeChannels = openConnection(node, connectionProfile);
-                        connectionValidator.accept(nodeChannels, connectionProfile);
-                    } catch (Exception e) {
-                        logger.trace(
-                            (Supplier<?>) () -> new ParameterizedMessage(
-                                "failed to connect to [{}], cleaning dangling connections", node), e);
-                        IOUtils.closeWhileHandlingException(nodeChannels);
-                        throw e;
-                    }
+                    nodeChannels = openConnection(node, connectionProfile);
+                    connectionValidator.accept(nodeChannels, connectionProfile);
                     // we acquire a connection lock, so no way there is an existing connection
                     connectedNodes.put(node, nodeChannels);
                     if (logger.isDebugEnabled()) {
                         logger.debug("connected to node [{}]", node);
                     }
                     transportServiceAdapter.onNodeConnected(node);
+                    success = true;
                 } catch (ConnectTransportException e) {
                     throw e;
                 } catch (Exception e) {
                     throw new ConnectTransportException(node, "general node connection failure", e);
+                } finally {
+                    if (success == false) { // close the connection if there is a failure
+                        logger.trace(
+                            (Supplier<?>) () -> new ParameterizedMessage(
+                                "failed to connect to [{}], cleaning dangling connections", node));
+                        IOUtils.closeWhileHandlingException(nodeChannels);
+                    }
                 }
             }
         } finally {
@@ -518,7 +519,20 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
         try {
             ensureOpen();
             try {
-                nodeChannels = connectToChannels(node, connectionProfile);
+                AtomicBoolean runOnce = new AtomicBoolean(false);
+                Consumer<Channel> onClose = c -> {
+                    assert isOpen(c) == false : "channel is still open when onClose is called";
+                    try {
+                        onChannelClosed(c);
+                    } finally {
+                        // we only need to disconnect from the nodes once since all other channels
+                        // will also try to run this we protect it from running multiple times.
+                        if (runOnce.compareAndSet(false, true)) {
+                            disconnectFromNodeChannel(c, "channel closed");
+                        }
+                    }
+                };
+                nodeChannels = connectToChannels(node, connectionProfile, onClose);
                 final Channel channel = nodeChannels.getChannels().get(0); // one channel is guaranteed by the connection profile
                 final TimeValue connectTimeout = connectionProfile.getConnectTimeout() == null ?
                     defaultConnectionProfile.getConnectTimeout() :
@@ -526,8 +540,9 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
                 final TimeValue handshakeTimeout = connectionProfile.getHandshakeTimeout() == null ?
                     connectTimeout : connectionProfile.getHandshakeTimeout();
                 final Version version = executeHandshake(node, channel, handshakeTimeout);
-                transportServiceAdapter.onConnectionOpened(nodeChannels);
                 nodeChannels = new NodeChannels(nodeChannels, version); // clone the channels - we now have the correct version
+                transportServiceAdapter.onConnectionOpened(nodeChannels);
+                openConnections.add(nodeChannels);
                 success = true;
                 return nodeChannels;
             } catch (ConnectTransportException e) {
@@ -580,24 +595,37 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
     /**
      * Disconnects from a node if a channel is found as part of that nodes channels.
      */
-    protected final void disconnectFromNodeChannel(final Channel channel, final Exception failure) {
+    protected final void disconnectFromNodeChannel(final Channel channel, final String reason) {
         threadPool.generic().execute(() -> {
             try {
-                try {
+                if (isOpen(channel)) {
                     closeChannels(Collections.singletonList(channel));
-                } finally {
-                    for (DiscoveryNode node : connectedNodes.keySet()) {
-                        if (disconnectFromNode(node, channel, ExceptionsHelper.detailedMessage(failure))) {
+                }
+            } catch (IOException e) {
+                logger.warn("failed to close channel", e);
+            } finally {
+                outer:
+                {
+                    for (Map.Entry<DiscoveryNode, NodeChannels> entry : connectedNodes.entrySet()) {
+                        if (disconnectFromNode(entry.getKey(), channel, reason)) {
                             // if we managed to find this channel and disconnect from it, then break, no need to check on
                             // the rest of the nodes
+                            // #onNodeChannelsClosed will remove it..
+                            assert openConnections.contains(entry.getValue()) == false : "NodeChannel#close should remove the connetion";
+                            // we can only be connected and published to a single node with one connection. So if disconnectFromNode
+                            // returns true we can safely break out from here since we cleaned up everything needed
+                            break outer;
+                        }
+                    }
+                    // now if we haven't found the right connection in the connected nodes we have to go through the open connections
+                    // it might be that the channel belongs to a connection that is not published
+                    for (NodeChannels channels : openConnections) {
+                        if (channels.hasChannel(channel)) {
+                            IOUtils.closeWhileHandlingException(channels);
                             break;
                         }
                     }
                 }
-            } catch (IOException e) {
-                logger.warn("failed to close channel", e);
-            } finally {
-                onChannelClosed(channel);
             }
         });
     }
@@ -901,12 +929,11 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
                                 "Error closing serverChannel for profile [{}]", entry.getKey()), e);
                     }
                 }
-
-                for (Iterator<NodeChannels> it = connectedNodes.values().iterator(); it.hasNext();) {
-                    NodeChannels nodeChannels = it.next();
-                    it.remove();
-                    IOUtils.closeWhileHandlingException(nodeChannels);
-                }
+                // we are holding a write lock so nobody modifies the connectedNodes / openConnections map - it's safe to first close
+                // all instances and then clear the maps
+                IOUtils.closeWhileHandlingException(Iterables.concat(connectedNodes.values(), openConnections));
+                openConnections.clear();
+                connectedNodes.clear();
                 stopInternal();
             } finally {
                 globalLock.writeLock().unlock();
@@ -923,11 +950,13 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
     }
 
     protected void onException(Channel channel, Exception e) {
+        String reason = ExceptionsHelper.detailedMessage(e);
         if (!lifecycle.started()) {
             // just close and ignore - we are already stopped and just need to make sure we release all resources
-            disconnectFromNodeChannel(channel, e);
+            disconnectFromNodeChannel(channel, reason);
             return;
         }
+
         if (isCloseConnectionException(e)) {
             logger.trace(
                 (Supplier<?>) () -> new ParameterizedMessage(
@@ -935,15 +964,15 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
                     channel),
                 e);
             // close the channel, which will cause a node to be disconnected if relevant
-            disconnectFromNodeChannel(channel, e);
+            disconnectFromNodeChannel(channel, reason);
         } else if (isConnectException(e)) {
             logger.trace((Supplier<?>) () -> new ParameterizedMessage("connect exception caught on transport layer [{}]", channel), e);
             // close the channel as safe measure, which will cause a node to be disconnected if relevant
-            disconnectFromNodeChannel(channel, e);
+            disconnectFromNodeChannel(channel, reason);
         } else if (e instanceof BindException) {
             logger.trace((Supplier<?>) () -> new ParameterizedMessage("bind exception caught on transport layer [{}]", channel), e);
             // close the channel as safe measure, which will cause a node to be disconnected if relevant
-            disconnectFromNodeChannel(channel, e);
+            disconnectFromNodeChannel(channel, reason);
         } else if (e instanceof CancelledKeyException) {
             logger.trace(
                 (Supplier<?>) () -> new ParameterizedMessage(
@@ -951,7 +980,7 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
                     channel),
                 e);
             // close the channel as safe measure, which will cause a node to be disconnected if relevant
-            disconnectFromNodeChannel(channel, e);
+            disconnectFromNodeChannel(channel, reason);
         } else if (e instanceof TcpTransport.HttpOnTransportException) {
             // in case we are able to return data, serialize the exception content and sent it back to the client
             if (isOpen(channel)) {
@@ -981,7 +1010,7 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
             logger.warn(
                 (Supplier<?>) () -> new ParameterizedMessage("exception caught on transport layer [{}], closing connection", channel), e);
             // close the channel, which will cause a node to be disconnected if relevant
-            disconnectFromNodeChannel(channel, e);
+            disconnectFromNodeChannel(channel, reason);
         }
     }
 
@@ -1012,7 +1041,8 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
      */
     protected abstract void sendMessage(Channel channel, BytesReference reference, ActionListener<Channel> listener);
 
-    protected abstract NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile connectionProfile) throws IOException;
+    protected abstract NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile connectionProfile,
+                                                      Consumer<Channel> onChannelClose) throws IOException;
 
     /**
      * Called to tear down internal resources
@@ -1607,7 +1637,7 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
     /**
      * Called once the channel is closed for instance due to a disconnect or a closed socket etc.
      */
-    protected final void onChannelClosed(Channel channel) {
+    private void onChannelClosed(Channel channel) {
         final Optional<Long> first = pendingHandshakes.entrySet().stream()
             .filter((entry) -> entry.getValue().channel == channel).map((e) -> e.getKey()).findFirst();
         if (first.isPresent()) {
@@ -1655,4 +1685,20 @@ public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent i
             Releasables.close(optionalReleasable, transportAdaptorCallback::run);
         }
     }
+
+    private void onNodeChannelsClosed(NodeChannels channels) {
+        // don't assert here since the channel / connection might not have been registered yet
+        final boolean remove = openConnections.remove(channels);
+        if (remove) {
+            transportServiceAdapter.onConnectionClosed(channels);
+        }
+    }
+
+    final int getNumOpenConnections() {
+        return openConnections.size();
+    }
+
+    final int getNumConnectedNodes() {
+        return connectedNodes.size();
+    }
 }

+ 4 - 3
core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java

@@ -224,8 +224,9 @@ public class TCPTransportTests extends ESTestCase {
                 }
 
                 @Override
-                protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile) throws IOException {
-                    return new NodeChannels(node, new Object[profile.getNumConnections()], profile, c -> {});
+                protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile,
+                                                         Consumer onChannelClose) throws IOException {
+                    return new NodeChannels(node, new Object[profile.getNumConnections()], profile);
                 }
 
                 @Override
@@ -241,7 +242,7 @@ public class TCPTransportTests extends ESTestCase {
                 @Override
                 public NodeChannels getConnection(DiscoveryNode node) {
                     return new NodeChannels(node, new Object[MockTcpTransport.LIGHT_PROFILE.getNumConnections()],
-                        MockTcpTransport.LIGHT_PROFILE, c -> {});
+                        MockTcpTransport.LIGHT_PROFILE);
                 }
             };
             DiscoveryNode node = new DiscoveryNode("foo", buildNewFakeTransportAddress(), Version.CURRENT);

+ 4 - 23
modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java

@@ -58,7 +58,6 @@ import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.FutureUtils;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
-import org.elasticsearch.monitor.jvm.JvmInfo;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.ConnectTransportException;
 import org.elasticsearch.transport.ConnectionProfile;
@@ -74,7 +73,6 @@ import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Consumer;
@@ -314,9 +312,9 @@ public class Netty4Transport extends TcpTransport<Channel> {
     }
 
     @Override
-    protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile) {
+    protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, Consumer<Channel> onChannelClose) {
         final Channel[] channels = new Channel[profile.getNumConnections()];
-        final NodeChannels nodeChannels = new NodeChannels(node, channels, profile, transportServiceAdapter::onConnectionClosed);
+        final NodeChannels nodeChannels = new NodeChannels(node, channels, profile);
         boolean success = false;
         try {
             final TimeValue connectTimeout;
@@ -336,6 +334,7 @@ public class Netty4Transport extends TcpTransport<Channel> {
                 connections.add(bootstrap.connect(address));
             }
             final Iterator<ChannelFuture> iterator = connections.iterator();
+            final ChannelFutureListener closeListener = future -> onChannelClose.accept(future.channel());
             try {
                 for (int i = 0; i < channels.length; i++) {
                     assert iterator.hasNext();
@@ -345,7 +344,7 @@ public class Netty4Transport extends TcpTransport<Channel> {
                         throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", future.cause());
                     }
                     channels[i] = future.channel();
-                    channels[i].closeFuture().addListener(new ChannelCloseListener(node));
+                    channels[i].closeFuture().addListener(closeListener);
                 }
                 assert iterator.hasNext() == false : "not all created connection have been consumed";
             } catch (final RuntimeException e) {
@@ -374,24 +373,6 @@ public class Netty4Transport extends TcpTransport<Channel> {
         return nodeChannels;
     }
 
-    private class ChannelCloseListener implements ChannelFutureListener {
-
-        private final DiscoveryNode node;
-
-        private ChannelCloseListener(DiscoveryNode node) {
-            this.node = node;
-        }
-
-        @Override
-        public void operationComplete(final ChannelFuture future) throws Exception {
-            onChannelClosed(future.channel());
-            NodeChannels nodeChannels = connectedNodes.get(node);
-            if (nodeChannels != null && nodeChannels.hasChannel(future.channel())) {
-                threadPool.generic().execute(() -> disconnectFromNode(node, future.channel(), "channel closed event"));
-            }
-        }
-    }
-
     @Override
     protected void sendMessage(Channel channel, BytesReference reference, ActionListener<Channel> listener) {
         final ChannelFuture future = channel.writeAndFlush(Netty4Utils.toByteBuf(reference));

+ 116 - 27
test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java

@@ -41,6 +41,7 @@ import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+import org.elasticsearch.common.util.concurrent.CountDown;
 import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
 import org.elasticsearch.mocksocket.MockServerSocket;
 import org.elasticsearch.node.Node;
@@ -60,6 +61,7 @@ import java.net.ServerSocket;
 import java.net.Socket;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -167,6 +169,8 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
         try {
             assertNoPendingHandshakes(serviceA.getOriginalTransport());
             assertNoPendingHandshakes(serviceB.getOriginalTransport());
+            assertPendingConnections(0, serviceA.getOriginalTransport());
+            assertPendingConnections(0, serviceB.getOriginalTransport());
         } finally {
             IOUtils.close(serviceA, serviceB, () -> {
                 try {
@@ -190,6 +194,13 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
         }
     }
 
+    public void assertPendingConnections(int numConnections, Transport transport) {
+        if (transport instanceof TcpTransport) {
+            TcpTransport tcpTransport = (TcpTransport) transport;
+            assertEquals(numConnections, tcpTransport.getNumOpenConnections() - tcpTransport.getNumConnectedNodes());
+        }
+    }
+
     public void testHelloWorld() {
         serviceA.registerRequestHandler("sayHello", StringMessageRequest::new, ThreadPool.Names.GENERIC,
             (request, channel) -> {
@@ -748,11 +759,9 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
 
     public void testNotifyOnShutdown() throws Exception {
         final CountDownLatch latch2 = new CountDownLatch(1);
-
-        serviceA.registerRequestHandler("foobar", StringMessageRequest::new, ThreadPool.Names.GENERIC,
-            new TransportRequestHandler<StringMessageRequest>() {
-                @Override
-                public void messageReceived(StringMessageRequest request, TransportChannel channel) {
+        try {
+            serviceA.registerRequestHandler("foobar", StringMessageRequest::new, ThreadPool.Names.GENERIC,
+                (request, channel) -> {
                     try {
                         latch2.await();
                         logger.info("Stop ServiceB now");
@@ -760,16 +769,19 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
                     } catch (Exception e) {
                         fail(e.getMessage());
                     }
-                }
-            });
-        TransportFuture<TransportResponse.Empty> foobar = serviceB.submitRequest(nodeA, "foobar",
-            new StringMessageRequest(""), TransportRequestOptions.EMPTY, EmptyTransportResponseHandler.INSTANCE_SAME);
-        latch2.countDown();
-        try {
-            foobar.txGet();
-            fail("TransportException expected");
-        } catch (TransportException ex) {
+                });
+            TransportFuture<TransportResponse.Empty> foobar = serviceB.submitRequest(nodeA, "foobar",
+                new StringMessageRequest(""), TransportRequestOptions.EMPTY, EmptyTransportResponseHandler.INSTANCE_SAME);
+            latch2.countDown();
+            try {
+                foobar.txGet();
+                fail("TransportException expected");
+            } catch (TransportException ex) {
 
+            }
+        } finally {
+            serviceB.close(); // make sure we are fully closed here otherwise we might run into assertions down the road
+            serviceA.disconnectFromNode(nodeB);
         }
     }
 
@@ -1469,12 +1481,9 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
 
     public void testMockUnresponsiveRule() throws IOException {
         serviceA.registerRequestHandler("sayHello", StringMessageRequest::new, ThreadPool.Names.GENERIC,
-            new TransportRequestHandler<StringMessageRequest>() {
-                @Override
-                public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception {
-                    assertThat("moshe", equalTo(request.message));
-                    throw new RuntimeException("bad message !!!");
-                }
+            (request, channel) -> {
+                assertThat("moshe", equalTo(request.message));
+                throw new RuntimeException("bad message !!!");
             });
 
         serviceB.addUnresponsiveRule(serviceA);
@@ -1852,7 +1861,11 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
         }
         logger.debug("DONE");
         serviceC.close();
-
+        // when we close C here we have to disconnect the service otherwise assertions might trip with pending connections in tearDown
+        // since the disconnect will then happen concurrently and that might confuse the assertions since we disconnect due to a
+        // connection reset by peer or other exceptions depending on the implementation
+        serviceB.disconnectFromNode(nodeC);
+        serviceA.disconnectFromNode(nodeC);
     }
 
     public void testRegisterHandlerTwice() {
@@ -2137,7 +2150,12 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
             @Override
             public void handleException(TransportException exp) {
                 try {
-                    assertTrue(exp.getClass().toString(), exp instanceof NodeDisconnectedException);
+                    if (exp instanceof SendRequestTransportException) {
+                        assertTrue(exp.getCause().getClass().toString(), exp.getCause() instanceof NodeNotConnectedException);
+                    } else {
+                        // here the concurrent disconnect was faster and invoked the listener first
+                        assertTrue(exp.getClass().toString(), exp instanceof NodeDisconnectedException);
+                    }
                 } finally {
                     latch.countDown();
                 }
@@ -2155,12 +2173,83 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
             TransportRequestOptions.Type.RECOVERY,
             TransportRequestOptions.Type.REG,
             TransportRequestOptions.Type.STATE);
-        Transport.Connection connection = serviceB.openConnection(serviceC.getLocalNode(), builder.build());
-        serviceB.sendRequest(connection, "action",  new TestRequest(randomFrom("fail", "pass")), TransportRequestOptions.EMPTY,
-            transportResponseHandler);
-        connection.close();
+        try (Transport.Connection connection = serviceB.openConnection(serviceC.getLocalNode(), builder.build())) {
+            serviceC.close();
+            serviceB.sendRequest(connection, "action", new TestRequest("boom"), TransportRequestOptions.EMPTY,
+                transportResponseHandler);
+        }
         latch.await();
-        serviceC.close();
+    }
+
+    public void testConcurrentDisconnectOnNonPublishedConnection() throws IOException, InterruptedException {
+        MockTransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true);
+        CountDownLatch receivedLatch = new CountDownLatch(1);
+        CountDownLatch sendResponseLatch = new CountDownLatch(1);
+        serviceC.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME,
+            (request, channel) -> {
+                // don't block on a network thread here
+                threadPool.generic().execute(new AbstractRunnable() {
+                    @Override
+                    public void onFailure(Exception e) {
+                        try {
+                            channel.sendResponse(e);
+                        } catch (IOException e1) {
+                            throw new UncheckedIOException(e1);
+                        }
+                    }
+
+                    @Override
+                    protected void doRun() throws Exception {
+                        receivedLatch.countDown();
+                        sendResponseLatch.await();
+                        channel.sendResponse(TransportResponse.Empty.INSTANCE);
+                    }
+                });
+            });
+        serviceC.start();
+        serviceC.acceptIncomingRequests();
+        CountDownLatch responseLatch = new CountDownLatch(1);
+        TransportResponseHandler<TransportResponse> transportResponseHandler = new TransportResponseHandler<TransportResponse>() {
+            @Override
+            public TransportResponse newInstance() {
+                return TransportResponse.Empty.INSTANCE;
+            }
+
+            @Override
+            public void handleResponse(TransportResponse response) {
+                responseLatch.countDown();
+            }
+
+            @Override
+            public void handleException(TransportException exp) {
+                responseLatch.countDown();
+            }
+
+            @Override
+            public String executor() {
+                return ThreadPool.Names.SAME;
+            }
+        };
+
+        ConnectionProfile.Builder builder = new ConnectionProfile.Builder();
+        builder.addConnections(1,
+            TransportRequestOptions.Type.BULK,
+            TransportRequestOptions.Type.PING,
+            TransportRequestOptions.Type.RECOVERY,
+            TransportRequestOptions.Type.REG,
+            TransportRequestOptions.Type.STATE);
+
+        try (Transport.Connection connection = serviceB.openConnection(serviceC.getLocalNode(), builder.build())) {
+            serviceB.sendRequest(connection, "action", new TestRequest("hello world"), TransportRequestOptions.EMPTY,
+                transportResponseHandler);
+            receivedLatch.await();
+            assertPendingConnections(1, serviceB.getOriginalTransport());
+            serviceC.close();
+            assertPendingConnections(0, serviceC.getOriginalTransport());
+            sendResponseLatch.countDown();
+            responseLatch.await();
+        }
+        assertPendingConnections(0, serviceC.getOriginalTransport());
     }
 
 }

+ 4 - 18
test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java

@@ -60,7 +60,6 @@ import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Executor;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
-import java.util.concurrent.RejectedExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Consumer;
@@ -178,25 +177,13 @@ public class MockTcpTransport extends TcpTransport<MockTcpTransport.MockChannel>
     }
 
     @Override
-    protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile) throws IOException {
+    protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile,
+                                             Consumer<MockChannel> onChannelClose) throws IOException {
         final MockChannel[] mockChannels = new MockChannel[1];
-        final NodeChannels nodeChannels = new NodeChannels(node, mockChannels, LIGHT_PROFILE,
-            transportServiceAdapter::onConnectionClosed); // we always use light here
+        final NodeChannels nodeChannels = new NodeChannels(node, mockChannels, LIGHT_PROFILE); // we always use light here
         boolean success = false;
         final MockSocket socket = new MockSocket();
         try {
-            Consumer<MockChannel> onClose = (channel) -> {
-                final NodeChannels connected = connectedNodes.get(node);
-                if (connected != null && connected.hasChannel(channel)) {
-                    try {
-                        executor.execute(() -> {
-                            disconnectFromNode(node, channel, "channel closed event");
-                        });
-                    } catch (RejectedExecutionException ex) {
-                        logger.debug("failed to run disconnectFromNode - node is shutting down");
-                    }
-                }
-            };
             final InetSocketAddress address = node.getAddress().address();
             // we just use a single connections
             configureSocket(socket);
@@ -206,7 +193,7 @@ public class MockTcpTransport extends TcpTransport<MockTcpTransport.MockChannel>
             } catch (SocketTimeoutException ex) {
                 throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", ex);
             }
-            MockChannel channel = new MockChannel(socket, address, "none", onClose);
+            MockChannel channel = new MockChannel(socket, address, "none", onChannelClose);
             channel.loopRead(executor);
             mockChannels[0] = channel;
             success = true;
@@ -376,7 +363,6 @@ public class MockTcpTransport extends TcpTransport<MockTcpTransport.MockChannel>
                 synchronized (openChannels) {
                     removedChannel = openChannels.remove(this);
                 }
-                onChannelClosed(this);
                 IOUtils.close(serverSocket, activeChannel, () -> IOUtils.close(workerChannels),
                     () -> cancellableThreads.cancel("channel closed"), onClose);
                 assert removedChannel: "Channel was not removed or removed twice?";