1
0
Эх сурвалжийг харах

Drain responses on completion for TransportNodesAction (#130303)

This PR ensures the node responses are copied and drained exclusively in
onCompletion so that they do not get concurrently modified by
cancellation.

Resolves: #128852
Yang Wang 3 сар өмнө
parent
commit
74fd66c1f1

+ 5 - 0
docs/changelog/130303.yaml

@@ -0,0 +1,5 @@
+pr: 130303
+summary: Drain responses on completion for `TransportNodesAction`
+area: Distributed
+type: bug
+issues: []

+ 20 - 8
server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java

@@ -42,6 +42,7 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.Executor;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.elasticsearch.core.Strings.format;
 
@@ -99,6 +100,7 @@ public abstract class TransportNodesAction<
             final ActionContext actionContext = createActionContext(task, request);
             final ArrayList<NodeResponse> responses = new ArrayList<>(concreteNodes.length);
             final ArrayList<FailedNodeException> exceptions = new ArrayList<>(0);
+            final AtomicBoolean responsesHandled = new AtomicBoolean(false);
 
             final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout());
 
@@ -109,12 +111,14 @@ public abstract class TransportNodesAction<
             private void addReleaseOnCancellationListener() {
                 if (task instanceof CancellableTask cancellableTask) {
                     cancellableTask.addListener(() -> {
-                        final List<NodeResponse> drainedResponses;
-                        synchronized (responses) {
-                            drainedResponses = List.copyOf(responses);
-                            responses.clear();
+                        if (responsesHandled.compareAndSet(false, true)) {
+                            final List<NodeResponse> drainedResponses;
+                            synchronized (responses) {
+                                drainedResponses = List.copyOf(responses);
+                                responses.clear();
+                            }
+                            Releasables.wrap(Iterators.map(drainedResponses.iterator(), r -> r::decRef)).close();
                         }
-                        Releasables.wrap(Iterators.map(drainedResponses.iterator(), r -> r::decRef)).close();
                     });
                 }
             }
@@ -161,10 +165,18 @@ public abstract class TransportNodesAction<
 
             @Override
             protected CheckedConsumer<ActionListener<NodesResponse>, Exception> onCompletion() {
-                // ref releases all happen-before here so no need to be synchronized
                 return l -> {
-                    try (var ignored = Releasables.wrap(Iterators.map(responses.iterator(), r -> r::decRef))) {
-                        newResponseAsync(task, request, actionContext, responses, exceptions, l);
+                    if (responsesHandled.compareAndSet(false, true)) {
+                        // ref releases all happen-before here so no need to be synchronized
+                        try (var ignored = Releasables.wrap(Iterators.map(responses.iterator(), r -> r::decRef))) {
+                            newResponseAsync(task, request, actionContext, responses, exceptions, l);
+                        }
+                    } else {
+                        logger.debug("task cancelled after all responses were collected");
+                        assert task instanceof CancellableTask : "expect CancellableTask, but got: " + task;
+                        final var cancellableTask = (CancellableTask) task;
+                        assert cancellableTask.isCancelled();
+                        cancellableTask.notifyIfCancelled(l);
                     }
                 };
             }

+ 136 - 0
server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java

@@ -10,6 +10,7 @@
 package org.elasticsearch.action.support.nodes;
 
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.support.ActionFilters;
@@ -57,6 +58,8 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.Executor;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -66,7 +69,9 @@ import java.util.function.ObjLongConsumer;
 import static java.util.Collections.emptyMap;
 import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;
 import static org.elasticsearch.test.ClusterServiceUtils.setState;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.hasSize;
 import static org.mockito.Mockito.mock;
 
 public class TransportNodesActionTests extends ESTestCase {
@@ -316,6 +321,137 @@ public class TransportNodesActionTests extends ESTestCase {
         assertTrue(cancellableTask.isCancelled()); // keep task alive
     }
 
+    public void testCompletionShouldNotBeInterferedByCancellationAfterProcessingBegins() throws Exception {
+        final var barrier = new CyclicBarrier(2);
+        final var action = new TestTransportNodesAction(
+            clusterService,
+            transportService,
+            new ActionFilters(Set.of()),
+            TestNodeRequest::new,
+            THREAD_POOL.executor(ThreadPool.Names.GENERIC)
+        ) {
+            @Override
+            protected void newResponseAsync(
+                Task task,
+                TestNodesRequest request,
+                Void unused,
+                List<TestNodeResponse> testNodeResponses,
+                List<FailedNodeException> failures,
+                ActionListener<TestNodesResponse> listener
+            ) {
+                boolean waited = false;
+                // Process node responses in a loop and ensure no ConcurrentModificationException will be thrown due to
+                // concurrent cancellation coming after the loop has started, see also #128852
+                for (var response : testNodeResponses) {
+                    if (waited == false) {
+                        waited = true;
+                        safeAwait(barrier);
+                        safeAwait(barrier);
+                    }
+                }
+                super.newResponseAsync(task, request, unused, testNodeResponses, failures, listener);
+            }
+        };
+
+        final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());
+        final var cancelledFuture = new PlainActionFuture<Void>();
+        cancellableTask.addListener(() -> cancelledFuture.onResponse(null));
+
+        final PlainActionFuture<TestNodesResponse> future = new PlainActionFuture<>();
+        action.execute(cancellableTask, new TestNodesRequest(), future);
+
+        for (var capturedRequest : transport.getCapturedRequestsAndClear()) {
+            completeOneRequest(capturedRequest);
+        }
+
+        // Wait for the overall response to start processing the node responses in a loop and then cancel the task.
+        // The cancellation should not interfere with the node response processing.
+        safeAwait(barrier);
+        TaskCancelHelper.cancel(cancellableTask, "simulated");
+        safeGet(cancelledFuture);
+
+        // Let the process continue, and it should be successful
+        safeAwait(barrier);
+        assertResponseReleased(safeGet(future));
+    }
+
+    public void testConcurrentlyCompletionAndCancellation() throws InterruptedException {
+        final var action = getTestTransportNodesAction();
+
+        final CountDownLatch onCancelledLatch = new CountDownLatch(1);
+        final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap()) {
+            @Override
+            protected void onCancelled() {
+                onCancelledLatch.countDown();
+            }
+        };
+
+        final PlainActionFuture<TestNodesResponse> future = new PlainActionFuture<>();
+        action.execute(cancellableTask, new TestNodesRequest(), future);
+
+        final List<TestNodeResponse> nodeResponses = new ArrayList<>();
+        final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear();
+        for (int i = 0; i < capturedRequests.length - 1; i++) {
+            final var capturedRequest = capturedRequests[i];
+            nodeResponses.add(completeOneRequest(capturedRequest));
+        }
+
+        final var raceBarrier = new CyclicBarrier(3);
+        final Thread completeThread = new Thread(() -> {
+            safeAwait(raceBarrier);
+            nodeResponses.add(completeOneRequest(capturedRequests[capturedRequests.length - 1]));
+        });
+        final Thread cancelThread = new Thread(() -> {
+            safeAwait(raceBarrier);
+            TaskCancelHelper.cancel(cancellableTask, "simulated");
+        });
+        completeThread.start();
+        cancelThread.start();
+        safeAwait(raceBarrier);
+
+        // We expect either a successful response or a cancellation exception. All node responses should be released in both cases.
+        try {
+            final var testNodesResponse = future.actionGet(SAFE_AWAIT_TIMEOUT);
+            assertThat(testNodesResponse.getNodes(), hasSize(capturedRequests.length));
+            assertResponseReleased(testNodesResponse);
+        } catch (Exception e) {
+            final var taskCancelledException = (TaskCancelledException) ExceptionsHelper.unwrap(e, TaskCancelledException.class);
+            assertNotNull("expect task cancellation exception, but got\n" + ExceptionsHelper.stackTrace(e), taskCancelledException);
+            assertThat(e.getMessage(), containsString("task cancelled [simulated]"));
+            assertTrue(cancellableTask.isCancelled());
+            safeAwait(onCancelledLatch); // wait for the latch, the listener for releasing node responses is called before it
+            assertTrue(nodeResponses.stream().allMatch(r -> r.hasReferences() == false));
+        }
+
+        completeThread.join(10_000);
+        cancelThread.join(10_000);
+        assertFalse(completeThread.isAlive());
+        assertFalse(cancelThread.isAlive());
+    }
+
+    private void assertResponseReleased(TestNodesResponse response) {
+        final var allResponsesReleasedListener = new SubscribableListener<Void>();
+        try (var listeners = new RefCountingListener(allResponsesReleasedListener)) {
+            response.addCloseListener(listeners.acquire());
+            for (final var nodeResponse : response.getNodes()) {
+                nodeResponse.addCloseListener(listeners.acquire());
+            }
+        }
+        safeAwait(allResponsesReleasedListener);
+        assertTrue(response.getNodes().stream().noneMatch(TestNodeResponse::hasReferences));
+        assertFalse(response.hasReferences());
+    }
+
+    private TestNodeResponse completeOneRequest(CapturingTransport.CapturedRequest capturedRequest) {
+        final var response = new TestNodeResponse(capturedRequest.node());
+        try {
+            transport.getTransportResponseHandler(capturedRequest.requestId()).handleResponse(response);
+        } finally {
+            response.decRef();
+        }
+        return response;
+    }
+
     @BeforeClass
     public static void startThreadPool() {
         THREAD_POOL = new TestThreadPool(TransportNodesActionTests.class.getSimpleName());