1
0
Эх сурвалжийг харах

Discard intermediate node results when a request is cancelled (#82685)

Resolves #82337
Mary Gouseti 3 жил өмнө
parent
commit
d4655e8801

+ 6 - 0
docs/changelog/82685.yaml

@@ -0,0 +1,6 @@
+pr: 82685
+summary: Discard intermediate results upon cancellation for stats endpoints
+area: Stats
+type: bug
+issues:
+ - 82337

+ 97 - 0
server/src/main/java/org/elasticsearch/action/support/NodeResponseTracker.java

@@ -0,0 +1,97 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.support;
+
+import java.util.Collection;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReferenceArray;
+
+/**
+ * This class tracks the intermediate responses that will be used to create aggregated cluster response to a request. It also gives the
+ * possibility to discard the intermediate results when asked, for example when the initial request is cancelled, in order to release the
+ * resources.
+ */
+public class NodeResponseTracker {
+
+    private final AtomicInteger counter = new AtomicInteger();
+    private final int expectedResponsesCount;
+    private volatile AtomicReferenceArray<Object> responses;
+    private volatile Exception causeOfDiscarding;
+
+    public NodeResponseTracker(int size) {
+        this.expectedResponsesCount = size;
+        this.responses = new AtomicReferenceArray<>(size);
+    }
+
+    public NodeResponseTracker(Collection<Object> array) {
+        this.expectedResponsesCount = array.size();
+        this.responses = new AtomicReferenceArray<>(array.toArray());
+    }
+
+    /**
+     * This method discards the results collected so far to free up the resources.
+     * @param cause the discarding, this will be communicated if they try to access the discarded results
+     */
+    public void discardIntermediateResponses(Exception cause) {
+        if (responses != null) {
+            this.causeOfDiscarding = cause;
+            responses = null;
+        }
+    }
+
+    public boolean responsesDiscarded() {
+        return responses == null;
+    }
+
+    /**
+     * This method stores a new node response if the intermediate responses haven't been discarded yet. If the responses are not discarded
+     * the method asserts that this is the first response encountered from this node to protect from miscounting the responses in case of a
+     * double invocation. If the responses have been discarded we accept this risk for simplicity.
+     * @param nodeIndex, the index that represents a single node of the cluster
+     * @param response, a response can be either a NodeResponse or an error
+     * @return true if all the nodes' responses have been received, else false
+     */
+    public boolean trackResponseAndCheckIfLast(int nodeIndex, Object response) {
+        AtomicReferenceArray<Object> responses = this.responses;
+
+        if (responsesDiscarded() == false) {
+            boolean firstEncounter = responses.compareAndSet(nodeIndex, null, response);
+            assert firstEncounter : "a response should be tracked only once";
+        }
+        return counter.incrementAndGet() == getExpectedResponseCount();
+    }
+
+    /**
+     * Returns the tracked response or null if the response hasn't been received yet for a specific index that represents a node of the
+     * cluster.
+     * @throws DiscardedResponsesException if the responses have been discarded
+     */
+    public Object getResponse(int nodeIndex) throws DiscardedResponsesException {
+        AtomicReferenceArray<Object> responses = this.responses;
+        if (responsesDiscarded()) {
+            throw new DiscardedResponsesException(causeOfDiscarding);
+        }
+        return responses.get(nodeIndex);
+    }
+
+    public int getExpectedResponseCount() {
+        return expectedResponsesCount;
+    }
+
+    /**
+     * This exception is thrown when the {@link NodeResponseTracker} is asked to give information about the responses after they have been
+     * discarded.
+     */
+    public static class DiscardedResponsesException extends Exception {
+
+        public DiscardedResponsesException(Exception cause) {
+            super(cause);
+        }
+    }
+}

+ 45 - 31
server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java

@@ -16,6 +16,7 @@ import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.DefaultShardOperationFailedException;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.action.support.NodeResponseTracker;
 import org.elasticsearch.action.support.TransportActions;
 import org.elasticsearch.action.support.broadcast.BroadcastRequest;
 import org.elasticsearch.action.support.broadcast.BroadcastResponse;
@@ -51,7 +52,6 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReferenceArray;
 import java.util.function.Consumer;
 
 /**
@@ -118,28 +118,29 @@ public abstract class TransportBroadcastByNodeAction<
 
     private Response newResponse(
         Request request,
-        AtomicReferenceArray<?> responses,
+        NodeResponseTracker nodeResponseTracker,
         int unavailableShardCount,
         Map<String, List<ShardRouting>> nodes,
         ClusterState clusterState
-    ) {
+    ) throws NodeResponseTracker.DiscardedResponsesException {
         int totalShards = 0;
         int successfulShards = 0;
         List<ShardOperationResult> broadcastByNodeResponses = new ArrayList<>();
         List<DefaultShardOperationFailedException> exceptions = new ArrayList<>();
-        for (int i = 0; i < responses.length(); i++) {
-            if (responses.get(i)instanceof FailedNodeException exception) {
+        for (int i = 0; i < nodeResponseTracker.getExpectedResponseCount(); i++) {
+            Object response = nodeResponseTracker.getResponse(i);
+            if (response instanceof FailedNodeException exception) {
                 totalShards += nodes.get(exception.nodeId()).size();
                 for (ShardRouting shard : nodes.get(exception.nodeId())) {
                     exceptions.add(new DefaultShardOperationFailedException(shard.getIndexName(), shard.getId(), exception));
                 }
             } else {
                 @SuppressWarnings("unchecked")
-                NodeResponse response = (NodeResponse) responses.get(i);
-                broadcastByNodeResponses.addAll(response.results);
-                totalShards += response.getTotalShards();
-                successfulShards += response.getSuccessfulShards();
-                for (BroadcastShardOperationFailedException throwable : response.getExceptions()) {
+                NodeResponse nodeResponse = (NodeResponse) response;
+                broadcastByNodeResponses.addAll(nodeResponse.results);
+                totalShards += nodeResponse.getTotalShards();
+                successfulShards += nodeResponse.getSuccessfulShards();
+                for (BroadcastShardOperationFailedException throwable : nodeResponse.getExceptions()) {
                     if (TransportActions.isShardNotAvailableException(throwable) == false) {
                         exceptions.add(
                             new DefaultShardOperationFailedException(
@@ -256,16 +257,15 @@ public abstract class TransportBroadcastByNodeAction<
         new AsyncAction(task, request, listener).start();
     }
 
-    protected class AsyncAction {
+    protected class AsyncAction implements CancellableTask.CancellationListener {
         private final Task task;
         private final Request request;
         private final ActionListener<Response> listener;
         private final ClusterState clusterState;
         private final DiscoveryNodes nodes;
         private final Map<String, List<ShardRouting>> nodeIds;
-        private final AtomicReferenceArray<Object> responses;
-        private final AtomicInteger counter = new AtomicInteger();
         private final int unavailableShardCount;
+        private final NodeResponseTracker nodeResponseTracker;
 
         protected AsyncAction(Task task, Request request, ActionListener<Response> listener) {
             this.task = task;
@@ -312,10 +312,13 @@ public abstract class TransportBroadcastByNodeAction<
 
             }
             this.unavailableShardCount = unavailableShardCount;
-            responses = new AtomicReferenceArray<>(nodeIds.size());
+            nodeResponseTracker = new NodeResponseTracker(nodeIds.size());
         }
 
         public void start() {
+            if (task instanceof CancellableTask cancellableTask) {
+                cancellableTask.addListener(this);
+            }
             if (nodeIds.size() == 0) {
                 try {
                     onCompletion();
@@ -373,38 +376,34 @@ public abstract class TransportBroadcastByNodeAction<
                 logger.trace("received response for [{}] from node [{}]", actionName, node.getId());
             }
 
-            // this is defensive to protect against the possibility of double invocation
-            // the current implementation of TransportService#sendRequest guards against this
-            // but concurrency is hard, safety is important, and the small performance loss here does not matter
-            if (responses.compareAndSet(nodeIndex, null, response)) {
-                if (counter.incrementAndGet() == responses.length()) {
-                    onCompletion();
-                }
+            if (nodeResponseTracker.trackResponseAndCheckIfLast(nodeIndex, response)) {
+                onCompletion();
             }
         }
 
         protected void onNodeFailure(DiscoveryNode node, int nodeIndex, Throwable t) {
             String nodeId = node.getId();
             logger.debug(new ParameterizedMessage("failed to execute [{}] on node [{}]", actionName, nodeId), t);
-
-            // this is defensive to protect against the possibility of double invocation
-            // the current implementation of TransportService#sendRequest guards against this
-            // but concurrency is hard, safety is important, and the small performance loss here does not matter
-            if (responses.compareAndSet(nodeIndex, null, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t))) {
-                if (counter.incrementAndGet() == responses.length()) {
-                    onCompletion();
-                }
+            if (nodeResponseTracker.trackResponseAndCheckIfLast(
+                nodeIndex,
+                new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t)
+            )) {
+                onCompletion();
             }
         }
 
         protected void onCompletion() {
-            if (task instanceof CancellableTask && ((CancellableTask) task).notifyIfCancelled(listener)) {
+            if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) {
                 return;
             }
 
             Response response = null;
             try {
-                response = newResponse(request, responses, unavailableShardCount, nodeIds, clusterState);
+                response = newResponse(request, nodeResponseTracker, unavailableShardCount, nodeIds, clusterState);
+            } catch (NodeResponseTracker.DiscardedResponsesException e) {
+                // We propagate the reason that the results, in this case the task cancellation, in case the listener needs to take
+                // follow-up actions
+                listener.onFailure((Exception) e.getCause());
             } catch (Exception e) {
                 logger.debug("failed to combine responses from nodes", e);
                 listener.onFailure(e);
@@ -417,6 +416,21 @@ public abstract class TransportBroadcastByNodeAction<
                 }
             }
         }
+
+        @Override
+        public void onCancelled() {
+            assert task instanceof CancellableTask : "task must be cancellable";
+            try {
+                ((CancellableTask) task).ensureNotCancelled();
+            } catch (TaskCancelledException e) {
+                nodeResponseTracker.discardIntermediateResponses(e);
+            }
+        }
+
+        // For testing purposes
+        public NodeResponseTracker getNodeResponseTracker() {
+            return nodeResponseTracker;
+        }
     }
 
     class BroadcastByNodeTransportRequestHandler implements TransportRequestHandler<NodeRequest> {

+ 44 - 21
server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java

@@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
+import org.elasticsearch.action.support.NodeResponseTracker;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
@@ -20,6 +21,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportChannel;
 import org.elasticsearch.transport.TransportException;
@@ -34,8 +36,6 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReferenceArray;
 
 public abstract class TransportNodesAction<
     NodesRequest extends BaseNodesRequest<NodesRequest>,
@@ -128,14 +128,15 @@ public abstract class TransportNodesAction<
      * pass it to the listener. Fails the listener with a {@link NullPointerException} if {@code nodesResponses} is null.
      *
      * @param request The associated request.
-     * @param nodesResponses All node-level responses
-     * @throws NullPointerException if {@code nodesResponses} is {@code null}
+     * @param nodeResponseTracker All node-level responses collected so far
+     * @throws NodeResponseTracker.DiscardedResponsesException if {@code nodeResponseTracker} has already discarded the intermediate results
      * @see #newResponseAsync(Task, BaseNodesRequest, List, List, ActionListener)
      */
     // exposed for tests
-    void newResponse(Task task, NodesRequest request, AtomicReferenceArray<?> nodesResponses, ActionListener<NodesResponse> listener) {
+    void newResponse(Task task, NodesRequest request, NodeResponseTracker nodeResponseTracker, ActionListener<NodesResponse> listener)
+        throws NodeResponseTracker.DiscardedResponsesException {
 
-        if (nodesResponses == null) {
+        if (nodeResponseTracker == null) {
             listener.onFailure(new NullPointerException("nodesResponses"));
             return;
         }
@@ -143,11 +144,10 @@ public abstract class TransportNodesAction<
         final List<NodeResponse> responses = new ArrayList<>();
         final List<FailedNodeException> failures = new ArrayList<>();
 
-        for (int i = 0; i < nodesResponses.length(); ++i) {
-            Object response = nodesResponses.get(i);
-
-            if (response instanceof FailedNodeException) {
-                failures.add((FailedNodeException) response);
+        for (int i = 0; i < nodeResponseTracker.getExpectedResponseCount(); ++i) {
+            Object response = nodeResponseTracker.getResponse(i);
+            if (nodeResponseTracker.getResponse(i)instanceof FailedNodeException failedNodeException) {
+                failures.add(failedNodeException);
             } else {
                 responses.add(nodeResponseClass.cast(response));
             }
@@ -203,12 +203,11 @@ public abstract class TransportNodesAction<
         return transportNodeAction;
     }
 
-    class AsyncAction {
+    class AsyncAction implements CancellableTask.CancellationListener {
 
         private final NodesRequest request;
         private final ActionListener<NodesResponse> listener;
-        private final AtomicReferenceArray<Object> responses;
-        private final AtomicInteger counter = new AtomicInteger();
+        private final NodeResponseTracker nodeResponseTracker;
         private final Task task;
 
         AsyncAction(Task task, NodesRequest request, ActionListener<NodesResponse> listener) {
@@ -219,10 +218,13 @@ public abstract class TransportNodesAction<
                 resolveRequest(request, clusterService.state());
                 assert request.concreteNodes() != null;
             }
-            this.responses = new AtomicReferenceArray<>(request.concreteNodes().length);
+            this.nodeResponseTracker = new NodeResponseTracker(request.concreteNodes().length);
         }
 
         void start() {
+            if (task instanceof CancellableTask cancellableTask) {
+                cancellableTask.addListener(this);
+            }
             final DiscoveryNode[] nodes = request.concreteNodes();
             if (nodes.length == 0) {
                 finishHim();
@@ -267,28 +269,49 @@ public abstract class TransportNodesAction<
             }
         }
 
+        // For testing purposes
+        NodeResponseTracker getNodeResponseTracker() {
+            return nodeResponseTracker;
+        }
+
         private void onOperation(int idx, NodeResponse nodeResponse) {
-            responses.set(idx, nodeResponse);
-            if (counter.incrementAndGet() == responses.length()) {
+            if (nodeResponseTracker.trackResponseAndCheckIfLast(idx, nodeResponse)) {
                 finishHim();
             }
         }
 
         private void onFailure(int idx, String nodeId, Throwable t) {
             logger.debug(new ParameterizedMessage("failed to execute on node [{}]", nodeId), t);
-            responses.set(idx, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t));
-            if (counter.incrementAndGet() == responses.length()) {
+            if (nodeResponseTracker.trackResponseAndCheckIfLast(idx, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t))) {
                 finishHim();
             }
         }
 
         private void finishHim() {
-            if (task instanceof CancellableTask && ((CancellableTask) task).notifyIfCancelled(listener)) {
+            if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) {
                 return;
             }
 
             final String executor = finalExecutor.equals(ThreadPool.Names.SAME) ? ThreadPool.Names.GENERIC : finalExecutor;
-            threadPool.executor(executor).execute(() -> newResponse(task, request, responses, listener));
+            threadPool.executor(executor).execute(() -> {
+                try {
+                    newResponse(task, request, nodeResponseTracker, listener);
+                } catch (NodeResponseTracker.DiscardedResponsesException e) {
+                    // We propagate the reason that the results, in this case the task cancellation, in case the listener needs to take
+                    // follow-up actions
+                    listener.onFailure((Exception) e.getCause());
+                }
+            });
+        }
+
+        @Override
+        public void onCancelled() {
+            assert task instanceof CancellableTask : "task must be cancellable";
+            try {
+                ((CancellableTask) task).ensureNotCancelled();
+            } catch (TaskCancelledException e) {
+                nodeResponseTracker.discardIntermediateResponses(e);
+            }
         }
     }
 

+ 24 - 0
server/src/main/java/org/elasticsearch/tasks/CancellableTask.java

@@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.Nullable;
 
 import java.util.Map;
+import java.util.concurrent.ConcurrentLinkedQueue;
 
 /**
  * A task that can be cancelled
@@ -20,6 +21,7 @@ public class CancellableTask extends Task {
 
     private volatile String reason;
     private volatile boolean isCancelled;
+    private final ConcurrentLinkedQueue<CancellationListener> listeners = new ConcurrentLinkedQueue<>();
 
     public CancellableTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
         super(id, type, action, description, parentTaskId, headers);
@@ -37,6 +39,7 @@ public class CancellableTask extends Task {
             this.isCancelled = true;
             this.reason = reason;
         }
+        listeners.forEach(CancellationListener::onCancelled);
         onCancelled();
     }
 
@@ -67,6 +70,20 @@ public class CancellableTask extends Task {
         return reason;
     }
 
+    /**
+     * This method adds a listener that needs to be notified if this task is cancelled.
+     */
+    public final void addListener(CancellationListener listener) {
+        synchronized (this) {
+            if (this.isCancelled == false) {
+                listeners.add(listener);
+            }
+        }
+        if (isCancelled) {
+            listener.onCancelled();
+        }
+    }
+
     /**
      * Called after the task is cancelled so that it can take any actions that it has to take.
      */
@@ -103,4 +120,11 @@ public class CancellableTask extends Task {
         assert reason != null;
         return new TaskCancelledException("task cancelled [" + reason + ']');
     }
+
+    /**
+     * This interface is implemented by any class that needs to react to the cancellation of this task.
+     */
+    public interface CancellationListener {
+        void onCancelled();
+    }
 }

+ 37 - 17
server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java

@@ -187,6 +187,19 @@ public class CancellableTasksTests extends TaskManagerTestCase {
         }
     }
 
+    /**
+     * Simulates a cancellation listener and sets a flag to true if the task was cancelled
+     */
+    static class CancellableTestCancellationListener implements CancellableTask.CancellationListener {
+
+        final AtomicBoolean calledUponCancellation = new AtomicBoolean(false);
+
+        @Override
+        public void onCancelled() {
+            calledUponCancellation.set(true);
+        }
+    }
+
     private Task startCancellableTestNodesAction(
         boolean waitForActionToStart,
         int runNodesCount,
@@ -252,6 +265,7 @@ public class CancellableTasksTests extends TaskManagerTestCase {
         setupTestNodes(Settings.EMPTY);
         connectNodes(testNodes);
         CountDownLatch responseLatch = new CountDownLatch(1);
+        AtomicBoolean listenerCalledUponCancellation = new AtomicBoolean(false);
         boolean waitForActionToStart = randomBoolean();
         logger.info("waitForActionToStart is set to {}", waitForActionToStart);
         final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();
@@ -260,24 +274,23 @@ public class CancellableTasksTests extends TaskManagerTestCase {
         // Block at least 1 node, otherwise it's quite easy to end up in a race condition where the node tasks
         // have finished before the cancel request has arrived
         int blockedNodesCount = randomIntBetween(1, runNodesCount);
-        Task mainTask = startCancellableTestNodesAction(
-            waitForActionToStart,
-            runNodesCount,
-            blockedNodesCount,
-            new ActionListener<NodesResponse>() {
-                @Override
-                public void onResponse(NodesResponse listTasksResponse) {
-                    responseReference.set(listTasksResponse);
-                    responseLatch.countDown();
-                }
+        Task mainTask = startCancellableTestNodesAction(waitForActionToStart, runNodesCount, blockedNodesCount, new ActionListener<>() {
+            @Override
+            public void onResponse(NodesResponse listTasksResponse) {
+                responseReference.set(listTasksResponse);
+                responseLatch.countDown();
+            }
 
-                @Override
-                public void onFailure(Exception e) {
-                    throwableReference.set(e);
-                    responseLatch.countDown();
-                }
+            @Override
+            public void onFailure(Exception e) {
+                throwableReference.set(e);
+                responseLatch.countDown();
             }
-        );
+        });
+
+        assert mainTask instanceof CancellableTask;
+        CancellableTestCancellationListener listenerAddedBeforeCancellation = new CancellableTestCancellationListener();
+        ((CancellableTask) mainTask).addListener(listenerAddedBeforeCancellation);
 
         // Cancel main task
         CancelTasksRequest request = new CancelTasksRequest();
@@ -311,6 +324,13 @@ public class CancellableTasksTests extends TaskManagerTestCase {
             for (TaskInfo taskInfo : response.getTasks()) {
                 assertTrue(taskInfo.cancellable());
             }
+
+            CancellableTestCancellationListener listenerAddedAfterCancellation = new CancellableTestCancellationListener();
+            ((CancellableTask) mainTask).addListener(listenerAddedAfterCancellation);
+
+            // Verify both cancellation listeners have been notified
+            assertTrue(listenerAddedBeforeCancellation.calledUponCancellation.get());
+            assertTrue(listenerAddedAfterCancellation.calledUponCancellation.get());
         }
 
         // Make sure that tasks are no longer running
@@ -337,7 +357,7 @@ public class CancellableTasksTests extends TaskManagerTestCase {
         final AtomicReference<Throwable> throwableReference = new AtomicReference<>();
         int runNodesCount = randomIntBetween(1, nodesCount);
         int blockedNodesCount = randomIntBetween(0, runNodesCount);
-        Task mainTask = startCancellableTestNodesAction(true, runNodesCount, blockedNodesCount, new ActionListener<NodesResponse>() {
+        Task mainTask = startCancellableTestNodesAction(true, runNodesCount, blockedNodesCount, new ActionListener<>() {
             @Override
             public void onResponse(NodesResponse listTasksResponse) {
                 responseReference.set(listTasksResponse);

+ 61 - 0
server/src/test/java/org/elasticsearch/action/support/NodeResponseTrackerTests.java

@@ -0,0 +1,61 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.support;
+
+import org.elasticsearch.test.ESTestCase;
+
+public class NodeResponseTrackerTests extends ESTestCase {
+
+    public void testAllResponsesReceived() throws Exception {
+        int nodes = randomIntBetween(1, 10);
+        NodeResponseTracker intermediateNodeResponses = new NodeResponseTracker(nodes);
+        for (int i = 0; i < nodes; i++) {
+            boolean isLast = i == nodes - 1;
+            assertEquals(
+                isLast,
+                intermediateNodeResponses.trackResponseAndCheckIfLast(i, randomBoolean() ? i : new Exception("from node " + i))
+            );
+        }
+
+        assertFalse(intermediateNodeResponses.responsesDiscarded());
+        assertEquals(nodes, intermediateNodeResponses.getExpectedResponseCount());
+        for (int i = 0; i < nodes; i++) {
+            assertNotNull(intermediateNodeResponses.getResponse(i));
+            if (intermediateNodeResponses.getResponse(i)instanceof Integer nodeResponse) {
+                assertEquals(i, nodeResponse.intValue());
+            }
+        }
+    }
+
+    public void testDiscardingResults() {
+        int nodes = randomIntBetween(1, 10);
+        int cancelAt = randomIntBetween(0, Math.max(0, nodes - 2));
+        NodeResponseTracker intermediateNodeResponses = new NodeResponseTracker(nodes);
+        for (int i = 0; i < nodes; i++) {
+            if (i == cancelAt) {
+                intermediateNodeResponses.discardIntermediateResponses(new Exception("simulated"));
+            }
+            boolean isLast = i == nodes - 1;
+            assertEquals(
+                isLast,
+                intermediateNodeResponses.trackResponseAndCheckIfLast(i, randomBoolean() ? i : new Exception("from node " + i))
+            );
+        }
+
+        assertTrue(intermediateNodeResponses.responsesDiscarded());
+        assertEquals(nodes, intermediateNodeResponses.getExpectedResponseCount());
+        expectThrows(NodeResponseTracker.DiscardedResponsesException.class, () -> intermediateNodeResponses.getResponse(0));
+    }
+
+    public void testResponseIsRegisteredOnlyOnce() {
+        NodeResponseTracker intermediateNodeResponses = new NodeResponseTracker(1);
+        assertTrue(intermediateNodeResponses.trackResponseAndCheckIfLast(0, "response1"));
+        expectThrows(AssertionError.class, () -> intermediateNodeResponses.trackResponseAndCheckIfLast(0, "response2"));
+    }
+}

+ 11 - 2
server/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java

@@ -537,14 +537,23 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase {
     public void testNoResultAggregationIfTaskCancelled() {
         Request request = new Request(new String[] { TEST_INDEX });
         PlainActionFuture<Response> listener = new PlainActionFuture<>();
-        action.new AsyncAction(cancelledTask(), request, listener).start();
+        final CancellableTask task = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());
+        TransportBroadcastByNodeAction<Request, Response, TransportBroadcastByNodeAction.EmptyResult>.AsyncAction asyncAction =
+            action.new AsyncAction(task, request, listener);
+        asyncAction.start();
         Map<String, List<CapturingTransport.CapturedRequest>> capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear();
-
+        int cancelAt = randomIntBetween(0, Math.max(0, capturedRequests.size() - 2));
+        int i = 0;
         for (Map.Entry<String, List<CapturingTransport.CapturedRequest>> entry : capturedRequests.entrySet()) {
+            if (cancelAt == i) {
+                TaskCancelHelper.cancel(task, "simulated");
+            }
             transport.handleRemoteError(entry.getValue().get(0).requestId(), new ElasticsearchException("simulated"));
+            i++;
         }
 
         assertTrue(listener.isDone());
+        assertTrue(asyncAction.getNodeResponseTracker().responsesDiscarded());
         expectThrows(ExecutionException.class, TaskCancelledException.class, listener::get);
     }
 

+ 16 - 8
server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java

@@ -11,6 +11,7 @@ package org.elasticsearch.action.support.nodes;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.NodeResponseTracker;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.action.support.broadcast.node.TransportBroadcastByNodeActionTests;
 import org.elasticsearch.cluster.ClusterName;
@@ -47,7 +48,6 @@ import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReferenceArray;
 import java.util.function.Supplier;
 
 import static java.util.Collections.emptyMap;
@@ -94,14 +94,14 @@ public class TransportNodesActionTests extends ESTestCase {
         assertEquals(clusterService.state().nodes().resolveNodes(finalNodesIds).length, capturedRequests.size());
     }
 
-    public void testNewResponseNullArray() {
+    public void testNewResponseNullArray() throws Exception {
         TransportNodesAction<TestNodesRequest, TestNodesResponse, TestNodeRequest, TestNodeResponse> action = getTestTransportNodesAction();
         final PlainActionFuture<TestNodesResponse> future = new PlainActionFuture<>();
         action.newResponse(new Task(1, "test", "test", "", null, emptyMap()), new TestNodesRequest(), null, future);
         expectThrows(NullPointerException.class, future::actionGet);
     }
 
-    public void testNewResponse() {
+    public void testNewResponse() throws Exception {
         TestTransportNodesAction action = getTestTransportNodesAction();
         TestNodesRequest request = new TestNodesRequest();
         List<TestNodeResponse> expectedNodeResponses = mockList(TestNodeResponse::new, randomIntBetween(0, 2));
@@ -120,10 +120,10 @@ public class TransportNodesActionTests extends ESTestCase {
 
         Collections.shuffle(allResponses, random());
 
-        AtomicReferenceArray<?> atomicArray = new AtomicReferenceArray<>(allResponses.toArray());
+        NodeResponseTracker nodeResponseCollector = new NodeResponseTracker(allResponses);
 
         final PlainActionFuture<TestNodesResponse> future = new PlainActionFuture<>();
-        action.newResponse(new Task(1, "test", "test", "", null, emptyMap()), request, atomicArray, future);
+        action.newResponse(new Task(1, "test", "test", "", null, emptyMap()), request, nodeResponseCollector, future);
         TestNodesResponse response = future.actionGet();
 
         assertSame(request, response.request);
@@ -146,7 +146,7 @@ public class TransportNodesActionTests extends ESTestCase {
         assertEquals(clusterService.state().nodes().getDataNodes().size(), capturedRequests.size());
     }
 
-    public void testTaskCancellationThrowsException() {
+    public void testTaskCancellation() {
         TransportNodesAction<TestNodesRequest, TestNodesResponse, TestNodeRequest, TestNodeResponse> action = getTestTransportNodesAction();
         List<String> nodeIds = new ArrayList<>();
         for (DiscoveryNode node : clusterService.state().nodes()) {
@@ -156,10 +156,16 @@ public class TransportNodesActionTests extends ESTestCase {
         TestNodesRequest request = new TestNodesRequest(nodeIds.toArray(new String[0]));
         PlainActionFuture<TestNodesResponse> listener = new PlainActionFuture<>();
         CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());
-        TaskCancelHelper.cancel(cancellableTask, "simulated");
-        action.doExecute(cancellableTask, request, listener);
+        TransportNodesAction<TestNodesRequest, TestNodesResponse, TestNodeRequest, TestNodeResponse>.AsyncAction asyncAction =
+            action.new AsyncAction(cancellableTask, request, listener);
+        asyncAction.start();
         Map<String, List<CapturingTransport.CapturedRequest>> capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear();
+        int cancelAt = randomIntBetween(0, Math.max(0, capturedRequests.values().size() - 2));
+        int requestCount = 0;
         for (List<CapturingTransport.CapturedRequest> requests : capturedRequests.values()) {
+            if (requestCount == cancelAt) {
+                TaskCancelHelper.cancel(cancellableTask, "simulated");
+            }
             for (CapturingTransport.CapturedRequest capturedRequest : requests) {
                 if (randomBoolean()) {
                     transport.handleResponse(capturedRequest.requestId(), new TestNodeResponse(capturedRequest.node()));
@@ -167,9 +173,11 @@ public class TransportNodesActionTests extends ESTestCase {
                     transport.handleRemoteError(capturedRequest.requestId(), new TaskCancelledException("simulated"));
                 }
             }
+            requestCount++;
         }
 
         assertTrue(listener.isDone());
+        assertTrue(asyncAction.getNodeResponseTracker().responsesDiscarded());
         expectThrows(ExecutionException.class, TaskCancelledException.class, listener::get);
     }