1
0
Эх сурвалжийг харах

Protect NodeConnectionsService from stale conns (#92558)

A call to `ConnectionTarget#connect` which happens strictly after all
calls that close connections should leave us connected to the target.
However concurrent calls to `ConnectionTarget#connect` can overlap, and
today this means that a connection returned from an earlier call may
overwrite one from a later call. The trouble is that the earlier
connection attempt may yield a closed connection (it was concurrent with
the disconnections) so we must not let it supersede the newer one.

With this commit we prevent concurrent connection attempts, which avoids
earlier attempts from overwriting the connections resulting from later
attempts.

When combined with #92546, closes #92029
David Turner 2 жил өмнө
parent
commit
1a650ecab3

+ 6 - 0
docs/changelog/92558.yaml

@@ -0,0 +1,6 @@
+pr: 92558
+summary: Protect `NodeConnectionsService` from stale conns
+area: Network
+type: bug
+issues:
+ - 92029

+ 92 - 44
server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java

@@ -21,6 +21,7 @@ import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.common.util.concurrent.ListenableFuture;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.TimeValue;
@@ -219,6 +220,13 @@ public class NodeConnectionsService extends AbstractLifecycleComponent {
         private final AtomicInteger consecutiveFailureCount = new AtomicInteger();
         private final AtomicReference<Releasable> connectionRef = new AtomicReference<>();
 
+        // all access to these fields is synchronized
+        private ActionListener<Void> pendingListener;
+        private boolean connectionInProgress;
+
+        // placeholder listener for a fire-and-forget connection attempt
+        private static final ActionListener<Void> NOOP = ActionListener.noop();
+
         ConnectionTarget(DiscoveryNode discoveryNode) {
             this.discoveryNode = discoveryNode;
         }
@@ -229,57 +237,97 @@ public class NodeConnectionsService extends AbstractLifecycleComponent {
 
         Runnable connect(ActionListener<Void> listener) {
             return () -> {
-                final boolean alreadyConnected = transportService.nodeConnected(discoveryNode);
+                registerListener(listener);
+                doConnect();
+            };
+        }
 
-                if (alreadyConnected) {
-                    logger.trace("refreshing connection to {}", discoveryNode);
-                } else {
-                    logger.debug("connecting to {}", discoveryNode);
+        private synchronized void registerListener(ActionListener<Void> listener) {
+            if (listener == null) {
+                pendingListener = pendingListener == null ? NOOP : pendingListener;
+            } else if (pendingListener == null || pendingListener == NOOP) {
+                pendingListener = listener;
+            } else if (pendingListener instanceof ListenableFuture<Void> listenableFuture) {
+                listenableFuture.addListener(listener);
+            } else {
+                var wrapper = new ListenableFuture<Void>();
+                wrapper.addListener(pendingListener);
+                wrapper.addListener(listener);
+                pendingListener = wrapper;
+            }
+        }
+
+        private synchronized ActionListener<Void> acquireListener() {
+            // Avoid concurrent connection attempts because they don't necessarily complete in order otherwise, and out-of-order completion
+            // might mean we end up disconnected from a node even though we triggered a call to connect() after all close() calls had
+            // finished.
+            if (connectionInProgress == false) {
+                var listener = pendingListener;
+                if (listener != null) {
+                    pendingListener = null;
+                    connectionInProgress = true;
+                    return listener;
                 }
+            }
+            return null;
+        }
 
-                // It's possible that connectionRef is a reference to an older connection that closed out from under us, but that something
-                // else has opened a fresh connection to the node. Therefore we always call connectToNode() and update connectionRef.
-                transportService.connectToNode(discoveryNode, new ActionListener<>() {
-                    @Override
-                    public void onResponse(Releasable connectionReleasable) {
-                        if (alreadyConnected) {
-                            logger.trace("refreshed connection to {}", discoveryNode);
-                        } else {
-                            logger.debug("connected to {}", discoveryNode);
-                        }
-                        consecutiveFailureCount.set(0);
-                        setConnectionRef(connectionReleasable);
-
-                        final boolean isActive;
-                        synchronized (mutex) {
-                            isActive = targetsByNode.get(discoveryNode) == ConnectionTarget.this;
-                        }
-                        if (isActive == false) {
-                            logger.debug("connected to stale {} - releasing stale connection", discoveryNode);
-                            setConnectionRef(null);
-                        }
-                        if (listener != null) {
-                            listener.onResponse(null);
-                        }
+        private synchronized void releaseListener() {
+            assert connectionInProgress;
+            connectionInProgress = false;
+        }
+
+        private void doConnect() {
+            var listener = acquireListener();
+            if (listener == null) {
+                return;
+            }
+
+            final boolean alreadyConnected = transportService.nodeConnected(discoveryNode);
+
+            if (alreadyConnected) {
+                logger.trace("refreshing connection to {}", discoveryNode);
+            } else {
+                logger.debug("connecting to {}", discoveryNode);
+            }
+
+            // It's possible that connectionRef is a reference to an older connection that closed out from under us, but that something else
+            // has opened a fresh connection to the node. Therefore we always call connectToNode() and update connectionRef.
+            transportService.connectToNode(discoveryNode, ActionListener.runAfter(new ActionListener<>() {
+                @Override
+                public void onResponse(Releasable connectionReleasable) {
+                    if (alreadyConnected) {
+                        logger.trace("refreshed connection to {}", discoveryNode);
+                    } else {
+                        logger.debug("connected to {}", discoveryNode);
                     }
+                    consecutiveFailureCount.set(0);
+                    setConnectionRef(connectionReleasable);
 
-                    @Override
-                    public void onFailure(Exception e) {
-                        final int currentFailureCount = consecutiveFailureCount.incrementAndGet();
-                        // only warn every 6th failure
-                        final Level level = currentFailureCount % 6 == 1 ? Level.WARN : Level.DEBUG;
-                        logger.log(
-                            level,
-                            () -> format("failed to connect to %s (tried [%s] times)", discoveryNode, currentFailureCount),
-                            e
-                        );
+                    final boolean isActive;
+                    synchronized (mutex) {
+                        isActive = targetsByNode.get(discoveryNode) == ConnectionTarget.this;
+                    }
+                    if (isActive == false) {
+                        logger.debug("connected to stale {} - releasing stale connection", discoveryNode);
                         setConnectionRef(null);
-                        if (listener != null) {
-                            listener.onFailure(e);
-                        }
                     }
-                });
-            };
+                    listener.onResponse(null);
+                }
+
+                @Override
+                public void onFailure(Exception e) {
+                    final int currentFailureCount = consecutiveFailureCount.incrementAndGet();
+                    // only warn every 6th failure
+                    final Level level = currentFailureCount % 6 == 1 ? Level.WARN : Level.DEBUG;
+                    logger.log(level, () -> format("failed to connect to %s (tried [%s] times)", discoveryNode, currentFailureCount), e);
+                    setConnectionRef(null);
+                    listener.onFailure(e);
+                }
+            }, () -> {
+                releaseListener();
+                transportService.getThreadPool().generic().execute(this::doConnect);
+            }));
         }
 
         void disconnect() {

+ 88 - 48
server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java

@@ -55,7 +55,9 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
@@ -97,9 +99,7 @@ public class NodeConnectionsServiceTests extends ESTestCase {
         final AtomicBoolean keepGoing = new AtomicBoolean(true);
         final Thread reconnectionThread = new Thread(() -> {
             while (keepGoing.get()) {
-                final PlainActionFuture<Void> future = new PlainActionFuture<>();
-                service.ensureConnections(() -> future.onResponse(null));
-                future.actionGet();
+                ensureConnections(service);
             }
         }, "reconnection thread");
         reconnectionThread.start();
@@ -109,34 +109,18 @@ public class NodeConnectionsServiceTests extends ESTestCase {
         final boolean isDisrupting = randomBoolean();
         final Thread disruptionThread = new Thread(() -> {
             while (isDisrupting && keepGoing.get()) {
-                final Transport.Connection connection;
-                try {
-                    connection = transportService.getConnection(randomFrom(allNodes));
-                } catch (NodeNotConnectedException e) {
-                    continue;
-                }
-
-                final PlainActionFuture<Void> future = new PlainActionFuture<>();
-                connection.addRemovedListener(future);
-                connection.close();
-                future.actionGet(10, TimeUnit.SECONDS);
+                closeConnection(transportService, randomFrom(allNodes));
             }
         }, "disruption thread");
         disruptionThread.start();
 
         for (int i = 0; i < 10; i++) {
-            final DiscoveryNodes connectNodes = discoveryNodesFromList(randomSubsetOf(allNodes));
-            final PlainActionFuture<Void> future = new PlainActionFuture<>();
-            service.connectToNodes(connectNodes, () -> future.onResponse(null));
-            future.actionGet(10, TimeUnit.SECONDS);
-            final DiscoveryNodes disconnectExceptNodes = discoveryNodesFromList(randomSubsetOf(allNodes));
-            service.disconnectFromNodesExcept(disconnectExceptNodes);
+            connectToNodes(service, discoveryNodesFromList(randomSubsetOf(allNodes)));
+            service.disconnectFromNodesExcept(discoveryNodesFromList(randomSubsetOf(allNodes)));
         }
 
         final DiscoveryNodes nodes = discoveryNodesFromList(randomSubsetOf(allNodes));
-        final PlainActionFuture<Void> connectFuture = new PlainActionFuture<>();
-        service.connectToNodes(nodes, () -> connectFuture.onResponse(null));
-        connectFuture.actionGet(10, TimeUnit.SECONDS);
+        connectToNodes(service, nodes);
         service.disconnectFromNodesExcept(nodes);
 
         assertTrue(keepGoing.compareAndSet(true, false));
@@ -144,14 +128,59 @@ public class NodeConnectionsServiceTests extends ESTestCase {
         disruptionThread.join();
 
         if (isDisrupting) {
-            final PlainActionFuture<Void> ensureFuture = new PlainActionFuture<>();
-            service.ensureConnections(() -> ensureFuture.onResponse(null));
-            ensureFuture.actionGet(10, TimeUnit.SECONDS);
+            ensureConnections(service);
         }
 
+        assertConnected(transportService, nodes);
         assertBusy(() -> assertConnectedExactlyToNodes(nodes));
     }
 
+    public void testConcurrentConnectAndDisconnect() throws Exception {
+        final NodeConnectionsService service = new NodeConnectionsService(Settings.EMPTY, threadPool, transportService);
+
+        final AtomicBoolean keepGoing = new AtomicBoolean(true);
+        final Thread reconnectionThread = new Thread(() -> {
+            while (keepGoing.get()) {
+                ensureConnections(service);
+            }
+        }, "reconnection thread");
+        reconnectionThread.start();
+
+        final var node = new DiscoveryNode("node", buildNewFakeTransportAddress(), Map.of(), Set.of(), Version.CURRENT);
+        final var nodes = discoveryNodesFromList(List.of(node));
+
+        final Thread disruptionThread = new Thread(() -> {
+            while (keepGoing.get()) {
+                closeConnection(transportService, node);
+            }
+        }, "disruption thread");
+        disruptionThread.start();
+
+        final var reconnectPermits = new Semaphore(1000);
+        final var reconnectThreads = 10;
+        final var reconnectCountDown = new CountDownLatch(reconnectThreads);
+        for (int i = 0; i < reconnectThreads; i++) {
+            threadPool.generic().execute(new Runnable() {
+                @Override
+                public void run() {
+                    if (reconnectPermits.tryAcquire()) {
+                        service.connectToNodes(nodes, () -> threadPool.generic().execute(this));
+                    } else {
+                        reconnectCountDown.countDown();
+                    }
+                }
+            });
+        }
+
+        assertTrue(reconnectCountDown.await(10, TimeUnit.SECONDS));
+        assertTrue(keepGoing.compareAndSet(true, false));
+        reconnectionThread.join();
+        disruptionThread.join();
+
+        ensureConnections(service);
+        assertConnectedExactlyToNodes(nodes);
+    }
+
     public void testPeriodicReconnection() {
         final Settings.Builder settings = Settings.builder();
         final long reconnectIntervalMillis;
@@ -234,30 +263,24 @@ public class NodeConnectionsServiceTests extends ESTestCase {
         // connect to one node
         final DiscoveryNode node0 = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT);
         final DiscoveryNodes nodes0 = DiscoveryNodes.builder().add(node0).build();
-        final PlainActionFuture<Void> future0 = new PlainActionFuture<>();
-        service.connectToNodes(nodes0, () -> future0.onResponse(null));
-        future0.actionGet(10, TimeUnit.SECONDS);
+        connectToNodes(service, nodes0);
         assertConnectedExactlyToNodes(nodes0);
 
         // connection attempts to node0 block indefinitely
         final CyclicBarrier connectionBarrier = new CyclicBarrier(2);
         try {
-            nodeConnectionBlocks.put(node0, connectionBarrier::await);
+            nodeConnectionBlocks.put(node0, () -> connectionBarrier.await(10, TimeUnit.SECONDS));
             transportService.disconnectFromNode(node0);
 
             // can still connect to another node without blocking
             final DiscoveryNode node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT);
             final DiscoveryNodes nodes1 = DiscoveryNodes.builder().add(node1).build();
             final DiscoveryNodes nodes01 = DiscoveryNodes.builder(nodes0).add(node1).build();
-            final PlainActionFuture<Void> future1 = new PlainActionFuture<>();
-            service.connectToNodes(nodes01, () -> future1.onResponse(null));
-            future1.actionGet(10, TimeUnit.SECONDS);
+            connectToNodes(service, nodes01);
             assertConnectedExactlyToNodes(nodes1);
 
             // can also disconnect from node0 without blocking
-            final PlainActionFuture<Void> future2 = new PlainActionFuture<>();
-            service.connectToNodes(nodes1, () -> future2.onResponse(null));
-            future2.actionGet(10, TimeUnit.SECONDS);
+            connectToNodes(service, nodes1);
             service.disconnectFromNodesExcept(nodes1);
             assertConnectedExactlyToNodes(nodes1);
 
@@ -273,17 +296,15 @@ public class NodeConnectionsServiceTests extends ESTestCase {
 
             // the reconnection is also blocked but the connection future doesn't wait, it completes straight away
             transportService.disconnectFromNode(node0);
-            final PlainActionFuture<Void> future4 = new PlainActionFuture<>();
-            service.connectToNodes(nodes01, () -> future4.onResponse(null));
-            future4.actionGet(10, TimeUnit.SECONDS);
+            connectToNodes(service, nodes01);
             assertConnectedExactlyToNodes(nodes1);
 
             // a blocked reconnection attempt doesn't also block the node from being deregistered
             service.disconnectFromNodesExcept(nodes1);
-            final PlainActionFuture<DiscoveryNode> disconnectFuture1 = new PlainActionFuture<>();
-            assertTrue(disconnectListenerRef.compareAndSet(null, disconnectFuture1));
-            connectionBarrier.await();
-            assertThat(disconnectFuture1.actionGet(10, TimeUnit.SECONDS), equalTo(node0)); // node0 connects briefly, must wait here
+            assertThat(PlainActionFuture.get(disconnectFuture1 -> {
+                assertTrue(disconnectListenerRef.compareAndSet(null, disconnectFuture1));
+                connectionBarrier.await(10, TimeUnit.SECONDS);
+            }, 10, TimeUnit.SECONDS), equalTo(node0)); // node0 connects briefly, must wait here
             assertConnectedExactlyToNodes(nodes1);
 
             // a blocked connection attempt to a new node also doesn't prevent an immediate deregistration
@@ -294,10 +315,10 @@ public class NodeConnectionsServiceTests extends ESTestCase {
             service.disconnectFromNodesExcept(nodes1);
             assertConnectedExactlyToNodes(nodes1);
 
-            final PlainActionFuture<DiscoveryNode> disconnectFuture2 = new PlainActionFuture<>();
-            assertTrue(disconnectListenerRef.compareAndSet(null, disconnectFuture2));
-            connectionBarrier.await(10, TimeUnit.SECONDS);
-            assertThat(disconnectFuture2.actionGet(10, TimeUnit.SECONDS), equalTo(node0)); // node0 connects briefly, must wait here
+            assertThat(PlainActionFuture.get(disconnectFuture2 -> {
+                assertTrue(disconnectListenerRef.compareAndSet(null, disconnectFuture2));
+                connectionBarrier.await(10, TimeUnit.SECONDS);
+            }, 10, TimeUnit.SECONDS), equalTo(node0)); // node0 connects briefly, must wait here
             assertConnectedExactlyToNodes(nodes1);
             assertTrue(future5.isDone());
         } finally {
@@ -310,7 +331,7 @@ public class NodeConnectionsServiceTests extends ESTestCase {
         reason = "testing that DEBUG-level logging is reasonable",
         value = "org.elasticsearch.cluster.NodeConnectionsService:DEBUG"
     )
-    public void testDebugLogging() throws IllegalAccessException {
+    public void testDebugLogging() {
         final DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue();
 
         MockTransport transport = new MockTransport(deterministicTaskQueue.getThreadPool());
@@ -706,4 +727,23 @@ public class NodeConnectionsServiceTests extends ESTestCase {
             return requestHandlers;
         }
     }
+
+    private static void connectToNodes(NodeConnectionsService service, DiscoveryNodes discoveryNodes) {
+        PlainActionFuture.get(future -> service.connectToNodes(discoveryNodes, () -> future.onResponse(null)), 10, TimeUnit.SECONDS);
+    }
+
+    private static void ensureConnections(NodeConnectionsService service) {
+        PlainActionFuture.get(future -> service.ensureConnections(() -> future.onResponse(null)), 10, TimeUnit.SECONDS);
+    }
+
+    private static void closeConnection(TransportService transportService, DiscoveryNode discoveryNode) {
+        try {
+            final var connection = transportService.getConnection(discoveryNode);
+            connection.close();
+            PlainActionFuture.get(connection::addRemovedListener, 10, TimeUnit.SECONDS);
+        } catch (NodeNotConnectedException e) {
+            // ok
+        }
+    }
+
 }