Kaynağa Gözat

Improve cancellability in TransportTasksAction (#96279)

Each `TransportTasksAction` fans out to multiple nodes, accumulates
responses and retains them until all the nodes have responded, and then
converts the responses into a final result.

Similarly to #92987 and #93484, we should accumulate the responses in a
structure that doesn't require so much copying later on, and should drop
the received responses if the task is cancelled while some nodes'
responses are still pending.
David Turner 2 yıl önce
ebeveyn
işleme
2513104dbb

+ 5 - 0
docs/changelog/96279.yaml

@@ -0,0 +1,5 @@
+pr: 96279
+summary: Improve cancellability in `TransportTasksAction`
+area: Task Management
+type: bug
+issues: []

+ 107 - 189
server/src/main/java/org/elasticsearch/action/support/tasks/TransportTasksAction.java

@@ -10,40 +10,36 @@ package org.elasticsearch.action.support.tasks;
 
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionListenerResponseHandler;
 import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.NoSuchNodeException;
 import org.elasticsearch.action.TaskOperationFailure;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.CancellableFanOut;
 import org.elasticsearch.action.support.ChannelActionListener;
 import org.elasticsearch.action.support.HandledTransportAction;
-import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.collect.Iterators;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.util.concurrent.AtomicArray;
-import org.elasticsearch.core.Tuple;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.transport.TransportChannel;
-import org.elasticsearch.transport.TransportException;
 import org.elasticsearch.transport.TransportRequest;
 import org.elasticsearch.transport.TransportRequestHandler;
 import org.elasticsearch.transport.TransportRequestOptions;
 import org.elasticsearch.transport.TransportResponse;
-import org.elasticsearch.transport.TransportResponseHandler;
 import org.elasticsearch.transport.TransportService;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
 import java.util.Map;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReferenceArray;
-
-import static java.util.Collections.emptyList;
 
 /**
  * The base class for transport actions that are interacting with currently running tasks.
@@ -85,67 +81,113 @@ public abstract class TransportTasksAction<
 
     @Override
     protected void doExecute(Task task, TasksRequest request, ActionListener<TasksResponse> listener) {
-        new AsyncAction(task, request, listener).start();
-    }
+        final var discoveryNodes = clusterService.state().nodes();
+        final String[] nodeIds = resolveNodes(request, discoveryNodes);
+
+        new CancellableFanOut<String, NodeTasksResponse, TasksResponse>() {
+            final ArrayList<TaskResponse> taskResponses = new ArrayList<>();
+            final ArrayList<TaskOperationFailure> taskOperationFailures = new ArrayList<>();
+            final ArrayList<FailedNodeException> failedNodeExceptions = new ArrayList<>();
+            final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.getTimeout());
+
+            @Override
+            protected void sendItemRequest(String nodeId, ActionListener<NodeTasksResponse> listener) {
+                final var discoveryNode = discoveryNodes.get(nodeId);
+                if (discoveryNode == null) {
+                    listener.onFailure(new NoSuchNodeException(nodeId));
+                    return;
+                }
+
+                transportService.sendChildRequest(
+                    discoveryNode,
+                    transportNodeAction,
+                    new NodeTaskRequest(request),
+                    task,
+                    transportRequestOptions,
+                    new ActionListenerResponseHandler<>(listener, nodeResponseReader)
+                );
+            }
+
+            @Override
+            protected void onItemResponse(String nodeId, NodeTasksResponse nodeTasksResponse) {
+                addAllSynchronized(taskResponses, nodeTasksResponse.results);
+                addAllSynchronized(taskOperationFailures, nodeTasksResponse.exceptions);
+            }
+
+            @SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
+            private static <T> void addAllSynchronized(List<T> allResults, Collection<T> response) {
+                if (response.isEmpty() == false) {
+                    synchronized (allResults) {
+                        allResults.addAll(response);
+                    }
+                }
+            }
+
+            @Override
+            protected void onItemFailure(String nodeId, Exception e) {
+                logger.debug(() -> Strings.format("failed to execute on node [{}]", nodeId), e);
+                synchronized (failedNodeExceptions) {
+                    failedNodeExceptions.add(new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", e));
+                }
+            }
+
+            @Override
+            protected TasksResponse onCompletion() {
+                // ref releases all happen-before here so no need to be synchronized
+                return newResponse(request, taskResponses, taskOperationFailures, failedNodeExceptions);
+            }
 
-    private void nodeOperation(CancellableTask task, NodeTaskRequest nodeTaskRequest, ActionListener<NodeTasksResponse> listener) {
-        TasksRequest request = nodeTaskRequest.tasksRequest;
-        processTasks(request, ActionListener.wrap(tasks -> nodeOperation(task, listener, request, tasks), listener::onFailure));
+            @Override
+            public String toString() {
+                return actionName;
+            }
+        }.run(task, Iterators.forArray(nodeIds), listener);
     }
 
+    // not an inline method reference to avoid capturing CancellableFanOut.this.
+    private final Writeable.Reader<NodeTasksResponse> nodeResponseReader = NodeTasksResponse::new;
+
     private void nodeOperation(
-        CancellableTask task,
+        CancellableTask nodeTask,
         ActionListener<NodeTasksResponse> listener,
         TasksRequest request,
-        List<OperationTask> tasks
+        List<OperationTask> operationTasks
     ) {
-        if (tasks.isEmpty()) {
-            listener.onResponse(new NodeTasksResponse(clusterService.localNode().getId(), emptyList(), emptyList()));
-            return;
-        }
-        AtomicArray<Tuple<TaskResponse, Exception>> responses = new AtomicArray<>(tasks.size());
-        final AtomicInteger counter = new AtomicInteger(tasks.size());
-        for (int i = 0; i < tasks.size(); i++) {
-            final int taskIndex = i;
-            ActionListener<TaskResponse> taskListener = new ActionListener<TaskResponse>() {
-                @Override
-                public void onResponse(TaskResponse response) {
-                    responses.setOnce(taskIndex, response == null ? null : new Tuple<>(response, null));
-                    respondIfFinished();
-                }
+        new CancellableFanOut<OperationTask, TaskResponse, NodeTasksResponse>() {
 
-                @Override
-                public void onFailure(Exception e) {
-                    responses.setOnce(taskIndex, new Tuple<>(null, e));
-                    respondIfFinished();
+            final ArrayList<TaskResponse> results = new ArrayList<>(operationTasks.size());
+            final ArrayList<TaskOperationFailure> exceptions = new ArrayList<>();
+
+            @Override
+            protected void sendItemRequest(OperationTask operationTask, ActionListener<TaskResponse> listener) {
+                ActionListener.run(listener, l -> taskOperation(nodeTask, request, operationTask, l));
+            }
+
+            @Override
+            protected void onItemResponse(OperationTask operationTask, TaskResponse taskResponse) {
+                synchronized (results) {
+                    results.add(taskResponse);
                 }
+            }
 
-                private void respondIfFinished() {
-                    if (counter.decrementAndGet() != 0) {
-                        return;
-                    }
-                    List<TaskResponse> results = new ArrayList<>();
-                    List<TaskOperationFailure> exceptions = new ArrayList<>();
-                    for (Tuple<TaskResponse, Exception> response : responses.asList()) {
-                        if (response.v1() == null) {
-                            assert response.v2() != null;
-                            exceptions.add(
-                                new TaskOperationFailure(clusterService.localNode().getId(), tasks.get(taskIndex).getId(), response.v2())
-                            );
-                        } else {
-                            assert response.v2() == null;
-                            results.add(response.v1());
-                        }
-                    }
-                    listener.onResponse(new NodeTasksResponse(clusterService.localNode().getId(), results, exceptions));
+            @Override
+            protected void onItemFailure(OperationTask operationTask, Exception e) {
+                synchronized (exceptions) {
+                    exceptions.add(new TaskOperationFailure(clusterService.localNode().getId(), operationTask.getId(), e));
                 }
-            };
-            try {
-                taskOperation(task, request, tasks.get(taskIndex), taskListener);
-            } catch (Exception e) {
-                taskListener.onFailure(e);
             }
-        }
+
+            @Override
+            protected NodeTasksResponse onCompletion() {
+                // ref releases all happen-before here so no need to be synchronized
+                return new NodeTasksResponse(clusterService.localNode().getId(), results, exceptions);
+            }
+
+            @Override
+            public String toString() {
+                return transportNodeAction;
+            }
+        }.run(nodeTask, operationTasks.iterator(), listener);
     }
 
     protected String[] resolveNodes(TasksRequest request, DiscoveryNodes discoveryNodes) {
@@ -192,28 +234,6 @@ public abstract class TransportTasksAction<
         List<FailedNodeException> failedNodeExceptions
     );
 
-    @SuppressWarnings("unchecked")
-    protected TasksResponse newResponse(TasksRequest request, AtomicReferenceArray<?> responses) {
-        List<TaskResponse> tasks = new ArrayList<>();
-        List<FailedNodeException> failedNodeExceptions = new ArrayList<>();
-        List<TaskOperationFailure> taskOperationFailures = new ArrayList<>();
-        for (int i = 0; i < responses.length(); i++) {
-            Object response = responses.get(i);
-            if (response instanceof FailedNodeException) {
-                failedNodeExceptions.add((FailedNodeException) response);
-            } else {
-                NodeTasksResponse tasksResponse = (NodeTasksResponse) response;
-                if (tasksResponse.results != null) {
-                    tasks.addAll(tasksResponse.results);
-                }
-                if (tasksResponse.exceptions != null) {
-                    taskOperationFailures.addAll(tasksResponse.exceptions);
-                }
-            }
-        }
-        return newResponse(request, tasks, taskOperationFailures, failedNodeExceptions);
-    }
-
     /**
      * Perform the required operation on the task. It is OK to start an asynchronous operation or to throw an exception but not both.
      * @param actionTask The related transport action task. Can be used to create a task ID to handle upstream transport cancellations.
@@ -228,120 +248,18 @@ public abstract class TransportTasksAction<
         ActionListener<TaskResponse> listener
     );
 
-    private class AsyncAction {
-
-        private final TasksRequest request;
-        private final String[] nodesIds;
-        private final DiscoveryNode[] nodes;
-        private final ActionListener<TasksResponse> listener;
-        private final AtomicReferenceArray<Object> responses;
-        private final AtomicInteger counter = new AtomicInteger();
-        private final Task task;
-
-        private AsyncAction(Task task, TasksRequest request, ActionListener<TasksResponse> listener) {
-            this.task = task;
-            this.request = request;
-            this.listener = listener;
-            final DiscoveryNodes discoveryNodes = clusterService.state().nodes();
-            this.nodesIds = resolveNodes(request, discoveryNodes);
-            Map<String, DiscoveryNode> nodes = discoveryNodes.getNodes();
-            this.nodes = new DiscoveryNode[nodesIds.length];
-            for (int i = 0; i < this.nodesIds.length; i++) {
-                this.nodes[i] = nodes.get(this.nodesIds[i]);
-            }
-            this.responses = new AtomicReferenceArray<>(this.nodesIds.length);
-        }
-
-        private void start() {
-            if (nodesIds.length == 0) {
-                // nothing to do
-                try {
-                    listener.onResponse(newResponse(request, responses));
-                } catch (Exception e) {
-                    logger.debug("failed to generate empty response", e);
-                    listener.onFailure(e);
-                }
-            } else {
-                final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.getTimeout());
-                for (int i = 0; i < nodesIds.length; i++) {
-                    final String nodeId = nodesIds[i];
-                    final int idx = i;
-                    final DiscoveryNode node = nodes[i];
-                    try {
-                        if (node == null) {
-                            onFailure(idx, nodeId, new NoSuchNodeException(nodeId));
-                        } else {
-                            NodeTaskRequest nodeRequest = new NodeTaskRequest(request);
-                            nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId());
-                            transportService.sendRequest(
-                                node,
-                                transportNodeAction,
-                                nodeRequest,
-                                transportRequestOptions,
-                                new TransportResponseHandler<NodeTasksResponse>() {
-                                    @Override
-                                    public NodeTasksResponse read(StreamInput in) throws IOException {
-                                        return new NodeTasksResponse(in);
-                                    }
-
-                                    @Override
-                                    public void handleResponse(NodeTasksResponse response) {
-                                        onOperation(idx, response);
-                                    }
-
-                                    @Override
-                                    public void handleException(TransportException exp) {
-                                        onFailure(idx, node.getId(), exp);
-                                    }
-                                }
-                            );
-                        }
-                    } catch (Exception e) {
-                        onFailure(idx, nodeId, e);
-                    }
-                }
-            }
-        }
-
-        private void onOperation(int idx, NodeTasksResponse nodeResponse) {
-            responses.set(idx, nodeResponse);
-            if (counter.incrementAndGet() == responses.length()) {
-                finishHim();
-            }
-        }
-
-        private void onFailure(int idx, String nodeId, Throwable t) {
-            logger.debug(() -> "failed to execute on node [" + nodeId + "]", t);
-
-            responses.set(idx, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t));
-
-            if (counter.incrementAndGet() == responses.length()) {
-                finishHim();
-            }
-        }
-
-        private void finishHim() {
-            if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) {
-                return;
-            }
-            TasksResponse finalResponse;
-            try {
-                finalResponse = newResponse(request, responses);
-            } catch (Exception e) {
-                logger.debug("failed to combine responses from nodes", e);
-                listener.onFailure(e);
-                return;
-            }
-            listener.onResponse(finalResponse);
-        }
-    }
-
     class NodeTransportHandler implements TransportRequestHandler<NodeTaskRequest> {
 
         @Override
         public void messageReceived(final NodeTaskRequest request, final TransportChannel channel, Task task) throws Exception {
             assert task instanceof CancellableTask;
-            nodeOperation((CancellableTask) task, request, new ChannelActionListener<>(channel));
+            TasksRequest tasksRequest = request.tasksRequest;
+            processTasks(
+                tasksRequest,
+                new ChannelActionListener<NodeTasksResponse>(channel).delegateFailure(
+                    (l, tasks) -> nodeOperation((CancellableTask) task, l, tasksRequest, tasks)
+                )
+            );
         }
     }
 

+ 152 - 0
server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java

@@ -7,6 +7,7 @@
  */
 package org.elasticsearch.action.admin.cluster.node.tasks;
 
+import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.ActionFuture;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.FailedNodeException;
@@ -40,6 +41,7 @@ import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.tasks.TaskInfo;
+import org.elasticsearch.test.ReachabilityChecker;
 import org.elasticsearch.test.tasks.MockTaskManager;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportRequest;
@@ -55,9 +57,12 @@ import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.atomic.AtomicReferenceArray;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.action.support.PlainActionFuture.newFuture;
@@ -68,6 +73,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
 
 public class TransportTasksActionTests extends TaskManagerTestCase {
 
@@ -674,6 +680,152 @@ public class TransportTasksActionTests extends TaskManagerTestCase {
         assertEquals(0, responses.failureCount());
     }
 
+    public void testTaskResponsesDiscardedOnCancellation() throws Exception {
+        setupTestNodes(Settings.EMPTY);
+        connectNodes(testNodes);
+        CountDownLatch blockedActionLatch = new CountDownLatch(1);
+        ActionFuture<NodesResponse> future = startBlockingTestNodesAction(blockedActionLatch);
+
+        final var taskResponseListeners = new LinkedBlockingQueue<ActionListener<TestTaskResponse>>();
+        final var taskResponseListenersCountDown = new CountDownLatch(2); // test action plus the list[n] action
+
+        final TestTasksAction tasksAction = new TestTasksAction(
+            "internal:testTasksAction",
+            testNodes[0].clusterService,
+            testNodes[0].transportService
+        ) {
+            @Override
+            protected void taskOperation(
+                CancellableTask actionTask,
+                TestTasksRequest request,
+                Task task,
+                ActionListener<TestTaskResponse> listener
+            ) {
+                taskResponseListeners.add(listener);
+                taskResponseListenersCountDown.countDown();
+            }
+        };
+
+        TestTasksRequest testTasksRequest = new TestTasksRequest();
+        testTasksRequest.setNodes(testNodes[0].getNodeId()); // only local node
+        PlainActionFuture<TestTasksResponse> taskFuture = newFuture();
+        CancellableTask task = (CancellableTask) testNodes[0].transportService.getTaskManager()
+            .registerAndExecute(
+                "direct",
+                tasksAction,
+                testTasksRequest,
+                testNodes[0].transportService.getLocalNodeConnection(),
+                taskFuture
+            );
+        safeAwait(taskResponseListenersCountDown);
+
+        final var reachabilityChecker = new ReachabilityChecker();
+
+        final var listener0 = Objects.requireNonNull(taskResponseListeners.poll());
+        if (randomBoolean()) {
+            listener0.onResponse(reachabilityChecker.register(new TestTaskResponse("status")));
+        } else {
+            listener0.onFailure(reachabilityChecker.register(new ElasticsearchException("simulated")));
+        }
+        reachabilityChecker.checkReachable();
+
+        PlainActionFuture.<Void, RuntimeException>get(
+            fut -> testNodes[0].transportService.getTaskManager().cancelTaskAndDescendants(task, "test", false, fut),
+            10,
+            TimeUnit.SECONDS
+        );
+
+        reachabilityChecker.ensureUnreachable();
+
+        while (true) {
+            final var listener = taskResponseListeners.poll();
+            if (listener == null) {
+                break;
+            }
+            if (randomBoolean()) {
+                listener.onResponse(reachabilityChecker.register(new TestTaskResponse("status")));
+            } else {
+                listener.onFailure(reachabilityChecker.register(new ElasticsearchException("simulated")));
+            }
+            reachabilityChecker.ensureUnreachable();
+        }
+
+        expectThrows(TaskCancelledException.class, taskFuture::actionGet);
+
+        blockedActionLatch.countDown();
+        NodesResponse responses = future.get(10, TimeUnit.SECONDS);
+        assertEquals(0, responses.failureCount());
+    }
+
+    public void testNodeResponsesDiscardedOnCancellation() {
+        setupTestNodes(Settings.EMPTY);
+        connectNodes(testNodes);
+
+        final var taskResponseListeners = new AtomicReferenceArray<ActionListener<TestTaskResponse>>(testNodes.length);
+        final var taskResponseListenersCountDown = new CountDownLatch(testNodes.length); // one list[n] action per node
+        final var tasksActions = new TestTasksAction[testNodes.length];
+        for (int i = 0; i < testNodes.length; i++) {
+            final var nodeIndex = i;
+            tasksActions[i] = new TestTasksAction("internal:testTasksAction", testNodes[i].clusterService, testNodes[i].transportService) {
+                @Override
+                protected void taskOperation(
+                    CancellableTask actionTask,
+                    TestTasksRequest request,
+                    Task task,
+                    ActionListener<TestTaskResponse> listener
+                ) {
+                    assertThat(taskResponseListeners.getAndSet(nodeIndex, ActionListener.notifyOnce(listener)), nullValue());
+                    taskResponseListenersCountDown.countDown();
+                }
+            };
+        }
+
+        TestTasksRequest testTasksRequest = new TestTasksRequest();
+        testTasksRequest.setActions("internal:testTasksAction[n]");
+        PlainActionFuture<TestTasksResponse> taskFuture = newFuture();
+        CancellableTask task = (CancellableTask) testNodes[0].transportService.getTaskManager()
+            .registerAndExecute(
+                "direct",
+                tasksActions[0],
+                testTasksRequest,
+                testNodes[0].transportService.getLocalNodeConnection(),
+                taskFuture
+            );
+        safeAwait(taskResponseListenersCountDown);
+
+        final var reachabilityChecker = new ReachabilityChecker();
+
+        if (randomBoolean()) {
+            // local node does not de/serialize node-level response so retains references to the task-level response
+            if (randomBoolean()) {
+                taskResponseListeners.get(0).onResponse(reachabilityChecker.register(new TestTaskResponse("status")));
+            } else {
+                taskResponseListeners.get(0).onFailure(reachabilityChecker.register(new ElasticsearchException("simulated")));
+            }
+            reachabilityChecker.checkReachable();
+        }
+
+        PlainActionFuture.<Void, RuntimeException>get(
+            fut -> testNodes[0].transportService.getTaskManager().cancelTaskAndDescendants(task, "test", false, fut),
+            10,
+            TimeUnit.SECONDS
+        );
+
+        reachabilityChecker.ensureUnreachable();
+        assertFalse(taskFuture.isDone());
+
+        for (int i = 0; i < testNodes.length; i++) {
+            if (randomBoolean()) {
+                taskResponseListeners.get(i).onResponse(reachabilityChecker.register(new TestTaskResponse("status")));
+            } else {
+                taskResponseListeners.get(i).onFailure(reachabilityChecker.register(new ElasticsearchException("simulated")));
+            }
+            reachabilityChecker.ensureUnreachable();
+        }
+
+        expectThrows(TaskCancelledException.class, taskFuture::actionGet);
+    }
+
     public void testTaskLevelActionFailures() throws Exception {
         setupTestNodes(Settings.EMPTY);
         connectNodes(testNodes);