Browse Source

Make list-tasks API cancellable (#96283)

In a busy cluster the list-tasks API may retain information about a very
large number of tasks while waiting for all nodes to respond. This
commit makes the API cancellable so that unnecessary partial results can
be released earlier.

Relates #96279, which implements the early-release functionality.
David Turner 2 years ago
parent
commit
e5111e388a

+ 72 - 0
qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/action/support/tasks/RestListTasksCancellationIT.java

@@ -0,0 +1,72 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.support.tasks;
+
+import org.apache.http.client.methods.HttpGet;
+import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksAction;
+import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.client.Cancellable;
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.Response;
+import org.elasticsearch.http.HttpSmokeTestCase;
+import org.elasticsearch.tasks.TaskManager;
+import org.elasticsearch.transport.TransportService;
+
+import java.util.ArrayList;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.action.support.ActionTestUtils.wrapAsRestResponseListener;
+import static org.elasticsearch.test.TaskAssertions.awaitTaskWithPrefix;
+
+public class RestListTasksCancellationIT extends HttpSmokeTestCase { // REST-level test: cancelling an in-flight GET /_tasks request removes its list-tasks tasks
+
+    public void testListTasksCancellation() throws Exception { // scenario: block a list-tasks call on a pending cluster-state task, then cancel it
+        final Request clusterStateRequest = new Request(HttpGet.METHOD_NAME, "/_cluster/state");
+        clusterStateRequest.addParameter("wait_for_metadata_version", Long.toString(Long.MAX_VALUE)); // Long.MAX_VALUE: expected to stay pending rather than complete
+        clusterStateRequest.addParameter("wait_for_timeout", "1h");
+
+        final PlainActionFuture<Response> clusterStateFuture = new PlainActionFuture<>();
+        final Cancellable clusterStateCancellable = getRestClient().performRequestAsync( // keep the handle so the request can be cancelled at the end
+            clusterStateRequest,
+            wrapAsRestResponseListener(clusterStateFuture)
+        );
+
+        awaitTaskWithPrefix(ClusterStateAction.NAME); // presumably blocks until a task with this action prefix is registered; see TaskAssertions
+
+        final Request tasksRequest = new Request(HttpGet.METHOD_NAME, "/_tasks");
+        tasksRequest.addParameter("actions", ClusterStateAction.NAME); // only list the cluster-state task started above
+        tasksRequest.addParameter("wait_for_completion", Boolean.toString(true)); // makes the list-tasks call wait on that still-pending task
+        tasksRequest.addParameter("timeout", "1h");
+
+        final PlainActionFuture<Response> tasksFuture = new PlainActionFuture<>();
+        final Cancellable tasksCancellable = getRestClient().performRequestAsync(tasksRequest, wrapAsRestResponseListener(tasksFuture));
+
+        awaitTaskWithPrefix(ListTasksAction.NAME + "[n]"); // "[n]" suffix: presumably the node-level child action of list-tasks; confirm
+
+        tasksCancellable.cancel(); // client-side cancel of the in-flight HTTP request
+
+        final var taskManagers = new ArrayList<TaskManager>(internalCluster().getNodeNames().length); // one TaskManager per TransportService instance
+        for (final var transportService : internalCluster().getInstances(TransportService.class)) {
+            taskManagers.add(transportService.getTaskManager());
+        }
+        assertBusy( // retried: the check below is expected to become true only after cancellation has propagated
+            () -> assertFalse(
+                taskManagers.stream()
+                    .flatMap(taskManager -> taskManager.getCancellableTasks().values().stream())
+                    .anyMatch(t -> t.getAction().startsWith(ListTasksAction.NAME)) // no cancellable list-tasks task may remain on any node
+            )
+        );
+
+        expectThrows(CancellationException.class, () -> tasksFuture.actionGet(10, TimeUnit.SECONDS)); // the client future must fail with CancellationException
+        clusterStateCancellable.cancel(); // cleanup of the 1h request; NOTE(review): not in a finally block, so skipped if an assertion above fails
+    }
+
+}

+ 7 - 0
server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/ListTasksRequest.java

@@ -14,9 +14,12 @@ import org.elasticsearch.action.support.tasks.BaseTasksRequest;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskId;
 
 import java.io.IOException;
+import java.util.Map;
 
 import static org.elasticsearch.action.ValidateActions.addValidationError;
 import static org.elasticsearch.common.regex.Regex.simpleMatch;
@@ -119,4 +122,8 @@ public class ListTasksRequest extends BaseTasksRequest<ListTasksRequest> {
         return this;
     }
 
+    @Override
+    public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) { // list-tasks requests are tracked as cancellable tasks
+        return new CancellableTask(id, type, action, "", parentTaskId, headers); // "" is presumably the task description; confirm against the CancellableTask constructor
+    }
 }

+ 10 - 2
server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java

@@ -24,6 +24,7 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.RemovedTaskListener;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.tasks.TaskInfo;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
@@ -76,7 +77,13 @@ public class TransportListTasksAction extends TransportTasksAction<Task, ListTas
     }
 
     @Override
-    protected void processTasks(ListTasksRequest request, ActionListener<List<Task>> nodeOperation) {
+    protected void doExecute(Task task, ListTasksRequest request, ActionListener<ListTasksResponse> listener) { // adds a sanity check, then delegates unchanged
+        assert task instanceof CancellableTask; // checked only when assertions are enabled; ListTasksRequest#createTask returns a CancellableTask
+        super.doExecute(task, request, listener);
+    }
+
+    @Override
+    protected void processTasks(CancellableTask nodeTask, ListTasksRequest request, ActionListener<List<Task>> nodeOperation) {
         if (request.getWaitForCompletion()) {
             final ListenableActionFuture<List<Task>> future = new ListenableActionFuture<>();
             final List<Task> processedTasks = new ArrayList<>();
@@ -137,8 +144,9 @@ public class TransportListTasksAction extends TransportTasksAction<Task, ListTas
                 threadPool,
                 ThreadPool.Names.SAME
             );
+            nodeTask.addListener(() -> future.onFailure(new TaskCancelledException("task cancelled")));
         } else {
-            super.processTasks(request, nodeOperation);
+            super.processTasks(nodeTask, request, nodeOperation);
         }
     }
 }

+ 2 - 1
server/src/main/java/org/elasticsearch/action/support/tasks/TransportTasksAction.java

@@ -198,7 +198,7 @@ public abstract class TransportTasksAction<
         }
     }
 
-    protected void processTasks(TasksRequest request, ActionListener<List<OperationTask>> nodeOperation) {
+    protected void processTasks(CancellableTask nodeTask, TasksRequest request, ActionListener<List<OperationTask>> nodeOperation) { // overridable per-node hook
         nodeOperation.onResponse(processTasks(request)); // default: complete the listener immediately with processTasks(request); nodeTask is not used here
     }
 
@@ -255,6 +255,7 @@ public abstract class TransportTasksAction<
             assert task instanceof CancellableTask;
             TasksRequest tasksRequest = request.tasksRequest;
             processTasks(
+                (CancellableTask) task,
                 tasksRequest,
                 new ChannelActionListener<NodeTasksResponse>(channel).delegateFailure(
                     (l, tasks) -> nodeOperation((CancellableTask) task, l, tasksRequest, tasks)

+ 4 - 1
server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestListTasksAction.java

@@ -18,6 +18,7 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.rest.BaseRestHandler;
 import org.elasticsearch.rest.RestChannel;
 import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.action.RestCancellableNodeClient;
 import org.elasticsearch.rest.action.RestChunkedToXContentListener;
 import org.elasticsearch.tasks.TaskId;
 
@@ -49,7 +50,9 @@ public class RestListTasksAction extends BaseRestHandler {
     public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException {
         final ListTasksRequest listTasksRequest = generateListTasksRequest(request);
         final String groupBy = request.param("group_by", "nodes");
-        return channel -> client.admin().cluster().listTasks(listTasksRequest, listTasksResponseListener(nodesInCluster, groupBy, channel));
+        return channel -> new RestCancellableNodeClient(client, request.getHttpChannel()).admin() // client tied to the HTTP channel; presumably cancels the task when the channel closes (confirm)
+            .cluster()
+            .listTasks(listTasksRequest, listTasksResponseListener(nodesInCluster, groupBy, channel));
     }
 
     public static ListTasksRequest generateListTasksRequest(RestRequest request) {

+ 13 - 9
test/framework/src/main/java/org/elasticsearch/action/support/ActionTestUtils.java

@@ -15,11 +15,12 @@ import org.elasticsearch.client.Response;
 import org.elasticsearch.client.ResponseListener;
 import org.elasticsearch.core.CheckedConsumer;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.transport.Transport;
 
-import static org.elasticsearch.action.support.PlainActionFuture.newFuture;
-import static org.mockito.Mockito.mock;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
 
 public class ActionTestUtils {
 
@@ -29,10 +30,11 @@ public class ActionTestUtils {
         TransportAction<Request, Response> action,
         Request request
     ) {
-        PlainActionFuture<Response> future = newFuture();
-        Task task = mock(Task.class);
-        action.execute(task, request, future);
-        return future.actionGet();
+        return PlainActionFuture.get(
+            future -> action.execute(request.createTask(1L, "direct", action.actionName, TaskId.EMPTY_TASK_ID, Map.of()), request, future),
+            10,
+            TimeUnit.SECONDS
+        );
     }
 
     public static <Request extends ActionRequest, Response extends ActionResponse> Response executeBlockingWithTask(
@@ -41,9 +43,11 @@ public class ActionTestUtils {
         TransportAction<Request, Response> action,
         Request request
     ) {
-        PlainActionFuture<Response> future = newFuture();
-        taskManager.registerAndExecute("transport", action, request, localConnection, future);
-        return future.actionGet();
+        return PlainActionFuture.get(
+            future -> taskManager.registerAndExecute("transport", action, request, localConnection, future),
+            10,
+            TimeUnit.SECONDS
+        );
     }
 
     /**