Browse Source

Fix race condition in `RemoteClusterService.collectNodes()` (#131937)

It is possible for a linked remote to get unlinked in
between the containsKey() and get() calls in collectNodes().
This change adds a test that produces the NullPointerException
and adds a fix.
Jeremy Dahlgren 2 months ago
parent
commit
ccf9893bd1

+ 5 - 0
docs/changelog/131937.yaml

@@ -0,0 +1,5 @@
+pr: 131937
+summary: Fix race condition in `RemoteClusterService.collectNodes()`
+area: Distributed
+type: bug
+issues: []

+ 13 - 23
server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java

@@ -17,6 +17,7 @@ import org.elasticsearch.action.OriginalIndices;
 import org.elasticsearch.action.support.CountDownActionListener;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.action.support.RefCountingListener;
 import org.elasticsearch.action.support.RefCountingRunnable;
 import org.elasticsearch.client.internal.RemoteClusterClient;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
@@ -29,7 +30,6 @@ import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
-import org.elasticsearch.common.util.concurrent.CountDown;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.TimeValue;
@@ -567,36 +567,26 @@ public final class RemoteClusterService extends RemoteClusterAware
                 "this node does not have the " + DiscoveryNodeRole.REMOTE_CLUSTER_CLIENT_ROLE.roleName() + " role"
             );
         }
+        final var connectionsMap = new HashMap<String, RemoteClusterConnection>();
         for (String cluster : clusters) {
-            if (this.remoteClusters.containsKey(cluster) == false) {
+            final var connection = this.remoteClusters.get(cluster);
+            if (connection == null) {
                 listener.onFailure(new NoSuchRemoteClusterException(cluster));
                 return;
             }
+            connectionsMap.put(cluster, connection);
         }
 
         final Map<String, Function<String, DiscoveryNode>> clusterMap = new HashMap<>();
-        CountDown countDown = new CountDown(clusters.size());
-        Function<String, DiscoveryNode> nullFunction = s -> null;
-        for (final String cluster : clusters) {
-            RemoteClusterConnection connection = this.remoteClusters.get(cluster);
-            connection.collectNodes(new ActionListener<Function<String, DiscoveryNode>>() {
-                @Override
-                public void onResponse(Function<String, DiscoveryNode> nodeLookup) {
-                    synchronized (clusterMap) {
-                        clusterMap.put(cluster, nodeLookup);
-                    }
-                    if (countDown.countDown()) {
-                        listener.onResponse((clusterAlias, nodeId) -> clusterMap.getOrDefault(clusterAlias, nullFunction).apply(nodeId));
-                    }
-                }
-
-                @Override
-                public void onFailure(Exception e) {
-                    if (countDown.fastForward()) { // we need to check if it's true since we could have multiple failures
-                        listener.onFailure(e);
-                    }
+        final var finalListener = listener.<Void>safeMap(
+            ignored -> (clusterAlias, nodeId) -> clusterMap.getOrDefault(clusterAlias, s -> null).apply(nodeId)
+        );
+        try (var refs = new RefCountingListener(finalListener)) {
+            connectionsMap.forEach((cluster, connection) -> connection.collectNodes(refs.acquire(nodeLookup -> {
+                synchronized (clusterMap) {
+                    clusterMap.put(cluster, nodeLookup);
                 }
-            });
+            })));
         }
     }
 

+ 81 - 0
server/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java

@@ -9,8 +9,10 @@
 package org.elasticsearch.transport;
 
 import org.apache.logging.log4j.Level;
+import org.apache.lucene.store.AlreadyClosedException;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.LatchedActionListener;
 import org.elasticsearch.action.OriginalIndices;
 import org.elasticsearch.action.support.ActionTestUtils;
 import org.elasticsearch.action.support.IndicesOptions;
@@ -1060,6 +1062,85 @@ public class RemoteClusterServiceTests extends ESTestCase {
         }
     }
 
+    public void testCollectNodesConcurrentWithSettingsChanges() throws IOException {
+        final List<DiscoveryNode> knownNodes_c1 = new CopyOnWriteArrayList<>();
+
+        try (
+            var c1N1 = startTransport(
+                "cluster_1_node_1",
+                knownNodes_c1,
+                VersionInformation.CURRENT,
+                TransportVersion.current(),
+                Settings.EMPTY
+            );
+            var transportService = MockTransportService.createNewService(
+                Settings.EMPTY,
+                VersionInformation.CURRENT,
+                TransportVersion.current(),
+                threadPool,
+                null
+            )
+        ) {
+            final var c1N1Node = c1N1.getLocalNode();
+            knownNodes_c1.add(c1N1Node);
+            final var seedList = List.of(c1N1Node.getAddress().toString());
+            transportService.start();
+            transportService.acceptIncomingRequests();
+
+            try (RemoteClusterService service = new RemoteClusterService(createSettings("cluster_1", seedList), transportService)) {
+                service.initializeRemoteClusters();
+                assertTrue(service.isCrossClusterSearchEnabled());
+                final var numTasks = between(3, 5);
+                final var taskLatch = new CountDownLatch(numTasks);
+
+                ESTestCase.startInParallel(numTasks, threadNumber -> {
+                    if (threadNumber == 0) {
+                        taskLatch.countDown();
+                        boolean isLinked = true;
+                        while (taskLatch.getCount() != 0) {
+                            final var future = new PlainActionFuture<RemoteClusterService.RemoteClusterConnectionStatus>();
+                            final var settings = createSettings("cluster_1", isLinked ? Collections.emptyList() : seedList);
+                            service.updateRemoteCluster("cluster_1", settings, future);
+                            safeGet(future);
+                            isLinked = isLinked == false;
+                        }
+                        return;
+                    }
+
+                    // Verify collectNodes() always invokes the listener, even if the node is concurrently being unlinked.
+                    try {
+                        for (int i = 0; i < 10; ++i) {
+                            final var latch = new CountDownLatch(1);
+                            final var exRef = new AtomicReference<Exception>();
+                            service.collectNodes(Set.of("cluster_1"), new LatchedActionListener<>(new ActionListener<>() {
+                                @Override
+                                public void onResponse(BiFunction<String, String, DiscoveryNode> func) {
+                                    assertEquals(c1N1Node, func.apply("cluster_1", c1N1Node.getId()));
+                                }
+
+                                @Override
+                                public void onFailure(Exception e) {
+                                    exRef.set(e);
+                                }
+                            }, latch));
+                            safeAwait(latch);
+                            if (exRef.get() != null) {
+                                assertThat(
+                                    exRef.get(),
+                                    either(instanceOf(TransportException.class)).or(instanceOf(NoSuchRemoteClusterException.class))
+                                        .or(instanceOf(AlreadyClosedException.class))
+                                        .or(instanceOf(NoSeedNodeLeftException.class))
+                                );
+                            }
+                        }
+                    } finally {
+                        taskLatch.countDown();
+                    }
+                });
+            }
+        }
+    }
+
     public void testRemoteClusterSkipIfDisconnectedSetting() {
         {
             Settings settings = Settings.builder()