Browse Source

ActionFilter-based REST cancellation tests (#101811)

Today it is a little tricky to write tests for REST cancellation since
we must pause the action while it is running to ensure that the
cancellation has the desired effect. Many actions have some way to
achieve this, possibly requiring some amount of trickery, but some
actions have no natural point at which they can be paused.

This commit introduces an `ActionFilter`-based approach to capture the
task just after it has been registered but before it starts executing.
All REST actions pass through the `ActionFilter` chain at least once, so
this technique should work everywhere.

This commit also removes the hooks that were added to test the
cancellability of the recovery APIs since they're no longer needed.
David Turner 1 year ago
parent
commit
6a50c1a04f

+ 0 - 92
qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesRecoveryRestCancellationIT.java

@@ -1,92 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0 and the Server Side Public License, v 1; you may not use this file except
- * in compliance with, at your election, the Elastic License 2.0 or the Server
- * Side Public License, v 1.
- */
-
-package org.elasticsearch.http;
-
-import org.apache.http.client.methods.HttpGet;
-import org.elasticsearch.action.admin.indices.recovery.RecoveryAction;
-import org.elasticsearch.action.admin.indices.recovery.TransportRecoveryAction;
-import org.elasticsearch.action.admin.indices.recovery.TransportRecoveryActionHelper;
-import org.elasticsearch.action.support.PlainActionFuture;
-import org.elasticsearch.client.Cancellable;
-import org.elasticsearch.client.Request;
-import org.elasticsearch.client.Response;
-import org.elasticsearch.core.Releasable;
-import org.elasticsearch.core.Releasables;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.CancellationException;
-import java.util.concurrent.Semaphore;
-
-import static org.elasticsearch.action.support.ActionTestUtils.wrapAsRestResponseListener;
-import static org.elasticsearch.test.TaskAssertions.assertAllCancellableTasksAreCancelled;
-import static org.elasticsearch.test.TaskAssertions.assertAllTasksHaveFinished;
-import static org.elasticsearch.test.TaskAssertions.awaitTaskWithPrefix;
-import static org.hamcrest.Matchers.empty;
-import static org.hamcrest.Matchers.not;
-
-public class IndicesRecoveryRestCancellationIT extends HttpSmokeTestCase {
-
-    public void testIndicesRecoveryRestCancellation() throws Exception {
-        runTest(new Request(HttpGet.METHOD_NAME, "/_recovery"));
-    }
-
-    public void testCatRecoveryRestCancellation() throws Exception {
-        runTest(new Request(HttpGet.METHOD_NAME, "/_cat/recovery"));
-    }
-
-    private void runTest(Request request) throws Exception {
-
-        createIndex("test");
-        ensureGreen("test");
-
-        final List<Semaphore> operationBlocks = new ArrayList<>();
-        for (final TransportRecoveryAction transportRecoveryAction : internalCluster().getInstances(TransportRecoveryAction.class)) {
-            final Semaphore operationBlock = new Semaphore(1);
-            operationBlocks.add(operationBlock);
-            TransportRecoveryActionHelper.setOnShardOperation(transportRecoveryAction, () -> {
-                try {
-                    operationBlock.acquire();
-                } catch (InterruptedException e) {
-                    throw new AssertionError(e);
-                }
-                operationBlock.release();
-            });
-        }
-        assertThat(operationBlocks, not(empty()));
-
-        final List<Releasable> releasables = new ArrayList<>();
-        try {
-            for (final Semaphore operationBlock : operationBlocks) {
-                operationBlock.acquire();
-                releasables.add(operationBlock::release);
-            }
-
-            final PlainActionFuture<Response> future = new PlainActionFuture<>();
-            logger.info("--> sending request");
-            final Cancellable cancellable = getRestClient().performRequestAsync(request, wrapAsRestResponseListener(future));
-
-            awaitTaskWithPrefix(RecoveryAction.NAME);
-
-            logger.info("--> waiting for at least one task to hit a block");
-            assertBusy(() -> assertTrue(operationBlocks.stream().anyMatch(Semaphore::hasQueuedThreads)));
-
-            logger.info("--> cancelling request");
-            cancellable.cancel();
-            expectThrows(CancellationException.class, future::actionGet);
-
-            assertAllCancellableTasksAreCancelled(RecoveryAction.NAME);
-        } finally {
-            Releasables.close(releasables);
-        }
-
-        assertAllTasksHaveFinished(RecoveryAction.NAME);
-    }
-
-}

+ 79 - 0
qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/RestActionCancellationIT.java

@@ -0,0 +1,79 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.http;
+
+import org.apache.http.client.methods.HttpGet;
+import org.elasticsearch.action.admin.cluster.health.ClusterHealthAction;
+import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
+import org.elasticsearch.action.admin.indices.recovery.RecoveryAction;
+import org.elasticsearch.action.support.CancellableActionTestPlugin;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.Response;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.CollectionUtils;
+import org.elasticsearch.plugins.Plugin;
+
+import java.util.Collection;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.action.support.ActionTestUtils.wrapAsRestResponseListener;
+import static org.elasticsearch.test.TaskAssertions.assertAllTasksHaveFinished;
+
+public class RestActionCancellationIT extends HttpSmokeTestCase {
+
+    public void testIndicesRecoveryRestCancellation() throws Exception {
+        createIndex("test");
+        ensureGreen("test");
+        runRestActionCancellationTest(new Request(HttpGet.METHOD_NAME, "/_recovery"), RecoveryAction.NAME);
+    }
+
+    public void testCatRecoveryRestCancellation() throws Exception {
+        createIndex("test");
+        ensureGreen("test");
+        runRestActionCancellationTest(new Request(HttpGet.METHOD_NAME, "/_cat/recovery"), RecoveryAction.NAME);
+    }
+
+    public void testClusterHealthRestCancellation() throws Exception {
+        runRestActionCancellationTest(new Request(HttpGet.METHOD_NAME, "/_cluster/health"), ClusterHealthAction.NAME);
+    }
+
+    public void testClusterStateRestCancellation() throws Exception {
+        runRestActionCancellationTest(new Request(HttpGet.METHOD_NAME, "/_cluster/state"), ClusterStateAction.NAME);
+    }
+
+    private void runRestActionCancellationTest(Request request, String actionName) throws Exception {
+        final var node = usually() ? internalCluster().getRandomNodeName() : internalCluster().startCoordinatingOnlyNode(Settings.EMPTY);
+
+        try (
+            var restClient = createRestClient(node);
+            var capturingAction = CancellableActionTestPlugin.capturingActionOnNode(actionName, node)
+        ) {
+            expectThrows(
+                CancellationException.class,
+                () -> PlainActionFuture.<Response, Exception>get(
+                    responseFuture -> capturingAction.captureAndCancel(
+                        restClient.performRequestAsync(request, wrapAsRestResponseListener(responseFuture))::cancel
+                    ),
+                    10,
+                    TimeUnit.SECONDS
+                )
+            );
+            assertAllTasksHaveFinished(actionName);
+        } catch (Exception e) {
+            fail(e);
+        }
+    }
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins() {
+        return CollectionUtils.appendToCopy(super.nodePlugins(), CancellableActionTestPlugin.class);
+    }
+}

+ 0 - 17
server/src/main/java/org/elasticsearch/action/admin/indices/recovery/TransportRecoveryAction.java

@@ -20,7 +20,6 @@ import org.elasticsearch.cluster.routing.ShardsIterator;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.core.Nullable;
 import org.elasticsearch.index.IndexService;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.indices.IndicesService;
@@ -102,7 +101,6 @@ public class TransportRecoveryAction extends TransportBroadcastByNodeAction<Reco
     protected void shardOperation(RecoveryRequest request, ShardRouting shardRouting, Task task, ActionListener<RecoveryState> listener) {
         ActionListener.completeWith(listener, () -> {
             assert task instanceof CancellableTask;
-            runOnShardOperation();
             IndexService indexService = indicesService.indexServiceSafe(shardRouting.shardId().getIndex());
             IndexShard indexShard = indexService.getShard(shardRouting.shardId().id());
             return indexShard.recoveryState();
@@ -123,19 +121,4 @@ public class TransportRecoveryAction extends TransportBroadcastByNodeAction<Reco
     protected ClusterBlockException checkRequestBlock(ClusterState state, RecoveryRequest request, String[] concreteIndices) {
         return state.blocks().indicesBlockedException(ClusterBlockLevel.METADATA_READ, concreteIndices);
     }
-
-    @Nullable // unless running tests that inject extra behaviour
-    private volatile Runnable onShardOperation;
-
-    private void runOnShardOperation() {
-        final Runnable onShardOperation = this.onShardOperation;
-        if (onShardOperation != null) {
-            onShardOperation.run();
-        }
-    }
-
-    // exposed for tests: inject some extra behaviour that runs when shardOperation() is called
-    void setOnShardOperation(@Nullable Runnable onShardOperation) {
-        this.onShardOperation = onShardOperation;
-    }
 }

+ 0 - 22
test/framework/src/main/java/org/elasticsearch/action/admin/indices/recovery/TransportRecoveryActionHelper.java

@@ -1,22 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0 and the Server Side Public License, v 1; you may not use this file except
- * in compliance with, at your election, the Elastic License 2.0 or the Server
- * Side Public License, v 1.
- */
-
-package org.elasticsearch.action.admin.indices.recovery;
-
-/**
- * Helper methods for {@link TransportRecoveryAction}.
- */
-public class TransportRecoveryActionHelper {
-
-    /**
-     * Helper method for tests to call {@link TransportRecoveryAction#setOnShardOperation}.
-     */
-    public static void setOnShardOperation(TransportRecoveryAction transportRecoveryAction, Runnable setOnShardOperation) {
-        transportRecoveryAction.setOnShardOperation(setOnShardOperation);
-    }
-}

+ 155 - 0
test/framework/src/main/java/org/elasticsearch/action/support/CancellableActionTestPlugin.java

@@ -0,0 +1,155 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.support;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.plugins.ActionPlugin;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.plugins.PluginsService;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskCancelledException;
+import org.elasticsearch.tasks.TaskManager;
+
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.elasticsearch.ExceptionsHelper.unwrapCause;
+import static org.elasticsearch.action.support.ActionTestUtils.assertNoFailureListener;
+import static org.elasticsearch.test.ESIntegTestCase.internalCluster;
+import static org.elasticsearch.test.ESTestCase.asInstanceOf;
+import static org.elasticsearch.test.ESTestCase.randomInt;
+import static org.elasticsearch.test.ESTestCase.safeAwait;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * Utility plugin that captures the invocation of an action on a node after the task has been registered with the {@link TaskManager},
+ * cancels it (e.g. by closing the connection used for the original REST request), verifies that the corresponding task is cancelled, then
+ * lets the action execution proceed in order to verify that it fails with a {@link TaskCancelledException}. This allows to verify a few key
+ * aspects of the cancellability of tasks:
+ * <ul>
+ *     <li>The task that the request creates is cancellable.</li>
+ *     <li>The REST handler propagates cancellation to the task it starts.</li>
+ *     <li>The action implementation checks for cancellation at least once.</li>
+ * </ul>
+ * However, note that this is implemented as an {@link ActionFilter} it blocks and cancels the action before it even starts executing on the
+ * local node, so it does not verify that the cancellation is processed promptly at all stages of the execution of the action, nor that
+ * cancellations are propagated correctly to subsidiary actions.
+ */
+public class CancellableActionTestPlugin extends Plugin implements ActionPlugin {
+
+    public interface CapturingAction extends Releasable {
+        /**
+         * @param doCancel callback to invoke when the specified action has started which should cancel the action.
+         */
+        void captureAndCancel(Runnable doCancel);
+    }
+
+    /**
+     * Returns a {@link CapturingAction}, typically for use in a try-with-resources block, which can be used to capture and cancel exactly
+     * one invocation of the specified action on the specified node.
+     */
+    public static CapturingAction capturingActionOnNode(String actionName, String nodeName) {
+        final var plugins = internalCluster().getInstance(PluginsService.class, nodeName)
+            .filterPlugins(CancellableActionTestPlugin.class)
+            .toList();
+        assertThat("unique " + CancellableActionTestPlugin.class.getCanonicalName() + " plugin not found", plugins, hasSize(1));
+        return plugins.get(0).capturingAction(actionName);
+    }
+
+    private volatile String capturedActionName;
+    private final AtomicReference<SubscribableListener<Captured>> capturedRef = new AtomicReference<>();
+
+    private record Captured(Runnable doCancel, CountDownLatch countDownLatch) {}
+
+    private CapturingAction capturingAction(String actionName) {
+        final var captureListener = new SubscribableListener<Captured>();
+        capturedActionName = actionName;
+        assertTrue(capturedRef.compareAndSet(null, captureListener));
+
+        final var completionLatch = new CountDownLatch(1);
+
+        return new CapturingAction() {
+            @Override
+            public void captureAndCancel(Runnable doCancel) {
+                assertFalse(captureListener.isDone());
+                captureListener.onResponse(new Captured(doCancel, completionLatch));
+                safeAwait(completionLatch);
+            }
+
+            @Override
+            public void close() {
+                // verify that a request was indeed captured
+                assertNull(capturedRef.get());
+                // and that it completed
+                assertEquals(0, completionLatch.getCount());
+            }
+        };
+    }
+
+    @Override
+    public List<ActionFilter> getActionFilters() {
+        return List.of(new ActionFilter() {
+
+            private final int order = randomInt();
+
+            @Override
+            public int order() {
+                return order;
+            }
+
+            @Override
+            public <Request extends ActionRequest, Response extends ActionResponse> void apply(
+                Task task,
+                String action,
+                Request request,
+                ActionListener<Response> listener,
+                ActionFilterChain<Request, Response> chain
+            ) {
+                if (action.equals(capturedActionName)) {
+                    final var capturingListener = capturedRef.getAndSet(null);
+                    if (capturingListener != null) {
+                        final var cancellableTask = asInstanceOf(CancellableTask.class, task);
+                        capturingListener.addListener(assertNoFailureListener(captured -> {
+                            cancellableTask.addListener(() -> chain.proceed(task, action, request, new ActionListener<>() {
+                                @Override
+                                public void onResponse(Response response) {
+                                    fail("cancelled action should not succeed, but got " + response);
+                                }
+
+                                @Override
+                                public void onFailure(Exception e) {
+                                    assertThat(unwrapCause(e), instanceOf(TaskCancelledException.class));
+                                    listener.onFailure(e);
+                                    captured.countDownLatch().countDown();
+                                }
+                            }));
+                            assertFalse(cancellableTask.isCancelled());
+                            captured.doCancel().run();
+                        }));
+                        return;
+                    }
+                }
+
+                chain.proceed(task, action, request, listener);
+            }
+        });
+    }
+}

+ 4 - 0
test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java

@@ -2436,6 +2436,10 @@ public abstract class ESIntegTestCase extends ESTestCase {
         return createRestClient(null, "http");
     }
 
+    protected static RestClient createRestClient(String node) {
+        return createRestClient(client(node).admin().cluster().prepareNodesInfo("_local").get().getNodes(), null, "http");
+    }
+
     protected static RestClient createRestClient(RestClientBuilder.HttpClientConfigCallback httpClientConfigCallback, String protocol) {
         NodesInfoResponse nodesInfoResponse = clusterAdmin().prepareNodesInfo().get();
         assertFalse(nodesInfoResponse.hasFailures());