Ver código fonte

Make field-caps tasks cancellable (#92051)

Nhat Nguyen 2 anos atrás
pai
commit
7b3db39701

+ 5 - 0
docs/changelog/92051.yaml

@@ -0,0 +1,5 @@
+pr: 92051
+summary: Make field-caps tasks cancellable
+area: Search
+type: enhancement
+issues: []

+ 126 - 2
server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/FieldCapabilitiesIT.java

@@ -8,6 +8,8 @@
 
 package org.elasticsearch.search.fieldcaps;
 
+import org.apache.http.entity.ContentType;
+import org.apache.http.entity.StringEntity;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.fieldcaps.FieldCapabilities;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction;
@@ -17,6 +19,10 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
 import org.elasticsearch.action.fieldcaps.TransportFieldCapabilitiesAction;
 import org.elasticsearch.action.index.IndexRequestBuilder;
 import org.elasticsearch.action.support.ActiveShardCount;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.client.Cancellable;
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.Response;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.routing.allocation.command.MoveAllocationCommand;
@@ -43,14 +49,18 @@ import org.elasticsearch.plugins.MapperPlugin;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.SearchPlugin;
 import org.elasticsearch.search.DummyQueryBuilder;
+import org.elasticsearch.tasks.TaskInfo;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xcontent.ObjectParser;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
 import org.junit.Before;
 
 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -58,6 +68,9 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Consumer;
 import java.util.function.Function;
@@ -65,12 +78,14 @@ import java.util.function.Predicate;
 import java.util.stream.IntStream;
 
 import static java.util.Collections.singletonList;
+import static org.elasticsearch.action.support.ActionTestUtils.wrapAsRestResponseListener;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
 import static org.hamcrest.Matchers.aMapWithSize;
 import static org.hamcrest.Matchers.array;
 import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.hasKey;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.not;
@@ -164,7 +179,17 @@ public class FieldCapabilitiesIT extends ESIntegTestCase {
 
     @Override
     protected Collection<Class<? extends Plugin>> nodePlugins() {
-        return List.of(TestMapperPlugin.class, ExceptionOnRewriteQueryPlugin.class);
+        return List.of(TestMapperPlugin.class, ExceptionOnRewriteQueryPlugin.class, BlockingOnRewriteQueryPlugin.class);
+    }
+
+    @Override
+    protected boolean addMockHttpTransport() {
+        return false; // enable http
+    }
+
+    @Override
+    protected boolean ignoreExternalCluster() {
+        return true;
     }
 
     public void testFieldAlias() {
@@ -641,6 +666,52 @@ public class FieldCapabilitiesIT extends ESIntegTestCase {
         assertTrue(resp.getField("extra_field").get("integer").isAggregatable());
     }
 
+    public void testCancel() throws Exception {
+        BlockingOnRewriteQueryBuilder.blockOnRewrite();
+        PlainActionFuture<Response> future = PlainActionFuture.newFuture();
+        Request restRequest = new Request("POST", "/_field_caps?fields=*");
+        restRequest.setEntity(new StringEntity("""
+                  {
+                    "index_filter": {
+                        "blocking_query": {}
+                     }
+                  }
+            """, ContentType.APPLICATION_JSON.withCharset(StandardCharsets.UTF_8)));
+        Cancellable cancellable = getRestClient().performRequestAsync(restRequest, wrapAsRestResponseListener(future));
+        logger.info("--> waiting for field-caps tasks to be started");
+        assertBusy(() -> {
+            List<TaskInfo> tasks = client().admin()
+                .cluster()
+                .prepareListTasks()
+                .setActions("indices:data/read/field_caps", "indices:data/read/field_caps[n]")
+                .get()
+                .getTasks();
+            assertThat(tasks.size(), greaterThanOrEqualTo(2));
+            for (TaskInfo task : tasks) {
+                assertTrue(task.cancellable());
+                assertFalse(task.cancelled());
+            }
+        }, 30, TimeUnit.SECONDS);
+
+        cancellable.cancel();
+        logger.info("--> waiting for field-caps tasks to be cancelled");
+        assertBusy(() -> {
+            List<TaskInfo> tasks = client().admin()
+                .cluster()
+                .prepareListTasks()
+                .setActions("indices:data/read/field_caps", "indices:data/read/field_caps[n]")
+                .get()
+                .getTasks();
+            for (TaskInfo task : tasks) {
+                assertTrue(task.cancellable());
+                assertTrue(task.cancelled());
+            }
+        }, 30, TimeUnit.SECONDS);
+
+        BlockingOnRewriteQueryBuilder.unblockOnRewrite();
+        expectThrows(CancellationException.class, future::actionGet);
+    }
+
     private void assertIndices(FieldCapabilitiesResponse response, String... indices) {
         assertNotNull(response.getIndices());
         Arrays.sort(indices);
@@ -680,7 +751,6 @@ public class FieldCapabilitiesIT extends ESIntegTestCase {
                 if (searchExecutionContext.indexMatches("*error*")) {
                     throw new IllegalArgumentException("I throw because I choose to.");
                 }
-                ;
             }
             return this;
         }
@@ -691,6 +761,60 @@ public class FieldCapabilitiesIT extends ESIntegTestCase {
         }
     }
 
+    public static class BlockingOnRewriteQueryPlugin extends Plugin implements SearchPlugin {
+
+        public BlockingOnRewriteQueryPlugin() {}
+
+        @Override
+        public List<QuerySpec<?>> getQueries() {
+            return List.of(
+                new QuerySpec<>("blocking_query", BlockingOnRewriteQueryBuilder::new, BlockingOnRewriteQueryBuilder::fromXContent)
+            );
+        }
+    }
+
+    static class BlockingOnRewriteQueryBuilder extends DummyQueryBuilder {
+        private static CountDownLatch blockingLatch = new CountDownLatch(1);
+        public static final String NAME = "blocking_query";
+
+        BlockingOnRewriteQueryBuilder() {
+
+        }
+
+        BlockingOnRewriteQueryBuilder(StreamInput in) throws IOException {
+            super(in);
+        }
+
+        static void blockOnRewrite() {
+            blockingLatch = new CountDownLatch(1);
+        }
+
+        static void unblockOnRewrite() {
+            blockingLatch.countDown();
+        }
+
+        @Override
+        protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
+            try {
+                blockingLatch.await();
+            } catch (InterruptedException e) {
+                throw new AssertionError(e);
+            }
+            return this;
+        }
+
+        public static BlockingOnRewriteQueryBuilder fromXContent(XContentParser parser) {
+            ObjectParser<BlockingOnRewriteQueryBuilder, Void> objectParser = new ObjectParser<>(NAME, BlockingOnRewriteQueryBuilder::new);
+            declareStandardFields(objectParser);
+            return objectParser.apply(parser, null);
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME;
+        }
+    }
+
     public static final class TestMapperPlugin extends Plugin implements MapperPlugin {
         @Override
         public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {

+ 3 - 1
server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesFetcher.java

@@ -23,6 +23,7 @@ import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.internal.AliasFilter;
 import org.elasticsearch.search.internal.ShardSearchRequest;
+import org.elasticsearch.tasks.CancellableTask;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -45,6 +46,7 @@ class FieldCapabilitiesFetcher {
     }
 
     FieldCapabilitiesIndexResponse fetch(
+        CancellableTask task,
         ShardId shardId,
         String[] fieldPatterns,
         String[] filters,
@@ -78,7 +80,7 @@ class FieldCapabilitiesFetcher {
                     return new FieldCapabilitiesIndexResponse(shardId.getIndexName(), indexMappingHash, existing, true);
                 }
             }
-
+            task.ensureNotCancelled();
             Predicate<String> fieldPredicate = indicesService.getFieldFilter().apply(shardId.getIndexName());
             final Map<String, IndexFieldCapabilities> responseMap = retrieveFieldCaps(
                 searchExecutionContext,

+ 27 - 0
server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesNodeRequest.java

@@ -19,6 +19,9 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskId;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -136,6 +139,30 @@ class FieldCapabilitiesNodeRequest extends ActionRequest implements IndicesReque
         return null;
     }
 
+    @Override
+    public String getDescription() {
+        final StringBuilder stringBuilder = new StringBuilder("shards[");
+        Strings.collectionToDelimitedStringWithLimit(shardIds, ",", "", "", 1024, stringBuilder);
+        stringBuilder.append("], fields[");
+        Strings.collectionToDelimitedStringWithLimit(Arrays.asList(fields), ",", "", "", 1024, stringBuilder);
+        stringBuilder.append("], filters[");
+        stringBuilder.append(Strings.collectionToDelimitedString(Arrays.asList(filters), ","));
+        stringBuilder.append("], types[");
+        stringBuilder.append(Strings.collectionToDelimitedString(Arrays.asList(allowedTypes), ","));
+        stringBuilder.append("]");
+        return stringBuilder.toString();
+    }
+
+    @Override
+    public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
+        return new CancellableTask(id, type, action, "", parentTaskId, headers) {
+            @Override
+            public String getDescription() {
+                return FieldCapabilitiesNodeRequest.this.getDescription();
+            }
+        };
+    }
+
     @Override
     public boolean equals(Object o) {
         if (this == o) return true;

+ 12 - 0
server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesRequest.java

@@ -18,6 +18,9 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 
@@ -271,4 +274,13 @@ public final class FieldCapabilitiesRequest extends ActionRequest implements Ind
         return stringBuilder.toString();
     }
 
+    @Override
+    public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
+        return new CancellableTask(id, type, action, "", parentTaskId, headers) {
+            @Override
+            public String getDescription() {
+                return FieldCapabilitiesRequest.this.getDescription();
+            }
+        };
+    }
 }

+ 11 - 2
server/src/main/java/org/elasticsearch/action/fieldcaps/TransportFieldCapabilitiesAction.java

@@ -29,6 +29,7 @@ import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.search.SearchService;
+import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.RemoteClusterAware;
@@ -93,6 +94,8 @@ public class TransportFieldCapabilitiesAction extends HandledTransportAction<Fie
         if (ccsCheckCompatibility) {
             checkCCSVersionCompatibility(request);
         }
+        assert task instanceof CancellableTask;
+        final CancellableTask fieldCapTask = (CancellableTask) task;
         // retrieve the initial timestamp in case the action is a cross cluster search
         long nowInMillis = request.nowInMillis() == null ? System.currentTimeMillis() : request.nowInMillis();
         final ClusterState clusterState = clusterService.state();
@@ -129,7 +132,7 @@ public class TransportFieldCapabilitiesAction extends HandledTransportAction<Fie
         final FailureCollector indexFailures = new FailureCollector();
         // One for each cluster including the local cluster
         final CountDown completionCounter = new CountDown(1 + remoteClusterIndices.size());
-        final Runnable countDown = createResponseMerger(request, completionCounter, indexResponses, indexFailures, listener);
+        final Runnable countDown = createResponseMerger(request, fieldCapTask, completionCounter, indexResponses, indexFailures, listener);
         final RequestDispatcher requestDispatcher = new RequestDispatcher(
             clusterService,
             transportService,
@@ -180,6 +183,7 @@ public class TransportFieldCapabilitiesAction extends HandledTransportAction<Fie
 
     private Runnable createResponseMerger(
         FieldCapabilitiesRequest request,
+        CancellableTask task,
         CountDown completionCounter,
         Map<String, FieldCapabilitiesIndexResponse> indexResponses,
         FailureCollector indexFailures,
@@ -193,7 +197,7 @@ public class TransportFieldCapabilitiesAction extends HandledTransportAction<Fie
                         // fork off to the management pool for merging the responses as the operation can run for longer than is acceptable
                         // on a transport thread in case of large numbers of indices and/or fields
                         threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION)
-                            .submit(ActionRunnable.supply(listener, () -> merge(indexResponses, request, new ArrayList<>(failures))));
+                            .submit(ActionRunnable.supply(listener, () -> merge(indexResponses, task, request, new ArrayList<>(failures))));
                     } else {
                         listener.onResponse(
                             new FieldCapabilitiesResponse(new ArrayList<>(indexResponses.values()), new ArrayList<>(failures))
@@ -238,9 +242,11 @@ public class TransportFieldCapabilitiesAction extends HandledTransportAction<Fie
 
     private FieldCapabilitiesResponse merge(
         Map<String, FieldCapabilitiesIndexResponse> indexResponsesMap,
+        CancellableTask task,
         FieldCapabilitiesRequest request,
         List<FieldCapabilitiesFailure> failures
     ) {
+        task.ensureNotCancelled();
         final FieldCapabilitiesIndexResponse[] indexResponses = indexResponsesMap.values()
             .stream()
             .sorted(Comparator.comparing(FieldCapabilitiesIndexResponse::getIndexName))
@@ -261,6 +267,7 @@ public class TransportFieldCapabilitiesAction extends HandledTransportAction<Fie
             }
         }
 
+        task.ensureNotCancelled();
         Map<String, Map<String, FieldCapabilities>> responseMap = new HashMap<>();
         for (Map.Entry<String, Map<String, FieldCapabilities.Builder>> entry : responseMapBuilder.entrySet()) {
             Map<String, FieldCapabilities.Builder> typeMapBuilder = entry.getValue();
@@ -387,6 +394,7 @@ public class TransportFieldCapabilitiesAction extends HandledTransportAction<Fie
     private class NodeTransportHandler implements TransportRequestHandler<FieldCapabilitiesNodeRequest> {
         @Override
         public void messageReceived(FieldCapabilitiesNodeRequest request, TransportChannel channel, Task task) throws Exception {
+            assert task instanceof CancellableTask;
             final ActionListener<FieldCapabilitiesNodeResponse> listener = new ChannelActionListener<>(channel, ACTION_NODE_NAME, request);
             ActionListener.completeWith(listener, () -> {
                 final List<FieldCapabilitiesIndexResponse> allResponses = new ArrayList<>();
@@ -404,6 +412,7 @@ public class TransportFieldCapabilitiesAction extends HandledTransportAction<Fie
                     for (ShardId shardId : shardIds) {
                         try {
                             final FieldCapabilitiesIndexResponse response = fetcher.fetch(
+                                (CancellableTask) task,
                                 shardId,
                                 request.fields(),
                                 request.filters(),

+ 4 - 1
server/src/main/java/org/elasticsearch/rest/action/RestFieldCapabilitiesAction.java

@@ -67,7 +67,10 @@ public class RestFieldCapabilitiesAction extends BaseRestHandler {
             }
             fieldRequest.fields(Strings.splitStringByCommaToArray(request.param("fields")));
         }
-        return channel -> client.fieldCaps(fieldRequest, new RestChunkedToXContentListener<>(channel));
+        return channel -> {
+            RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel());
+            cancelClient.fieldCaps(fieldRequest, new RestChunkedToXContentListener<>(channel));
+        };
     }
 
     private static final ParseField INDEX_FILTER_FIELD = new ParseField("index_filter");

+ 28 - 0
server/src/test/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesNodeRequestTests.java

@@ -26,6 +26,8 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
+import static org.hamcrest.Matchers.equalTo;
+
 public class FieldCapabilitiesNodeRequestTests extends AbstractWireSerializingTestCase<FieldCapabilitiesNodeRequest> {
 
     @Override
@@ -203,4 +205,30 @@ public class FieldCapabilitiesNodeRequestTests extends AbstractWireSerializingTe
             default -> throw new IllegalStateException("The test should only allow 7 parameters mutated");
         }
     }
+
+    public void testDescription() {
+        FieldCapabilitiesNodeRequest r1 = new FieldCapabilitiesNodeRequest(
+            List.of(new ShardId("index-1", "n/a", 0), new ShardId("index-2", "n/a", 3)),
+            new String[] { "field-1", "field-2" },
+            Strings.EMPTY_ARRAY,
+            Strings.EMPTY_ARRAY,
+            randomOriginalIndices(1),
+            null,
+            randomNonNegativeLong(),
+            Map.of()
+        );
+        assertThat(r1.getDescription(), equalTo("shards[[index-1][0],[index-2][3]], fields[field-1,field-2], filters[], types[]"));
+
+        FieldCapabilitiesNodeRequest r2 = new FieldCapabilitiesNodeRequest(
+            List.of(new ShardId("index-1", "n/a", 0)),
+            new String[] { "*" },
+            new String[] { "-nested", "-metadata" },
+            Strings.EMPTY_ARRAY,
+            randomOriginalIndices(1),
+            null,
+            randomNonNegativeLong(),
+            Map.of()
+        );
+        assertThat(r2.getDescription(), equalTo("shards[[index-1][0]], fields[*], filters[-nested,-metadata], types[]"));
+    }
 }