Browse Source

Aggs: Add cancellation checks to FilterByFilter aggregator (#130452) (#130745)

By default, the `FilterByFilterAggregator` (used by the `"filter"` and `"filters"` aggs) was using the `DefaultBulkScorer` (from Lucene), which has no cancellation mechanism.

This PR wraps it in a `CancellableBulkScorer`, which instead calls the inner scorer on successive ranges and checks for cancellation between them.

This should fix cases where long-running tasks using these aggregators were not cancelled, or at least greatly reduce the time they keep running after cancellation.
Iván Cea Fontenla, 3 months ago
parent
commit
aa5748c950

+ 5 - 0
docs/changelog/130452.yaml

@@ -0,0 +1,5 @@
+pr: 130452
+summary: "Aggs: Add cancellation checks to `FilterByFilter` aggregator"
+area: Aggregations
+type: bug
+issues: []

+ 216 - 0
server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/FiltersCancellationIT.java

@@ -0,0 +1,216 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.aggregations.bucket;
+
+import org.elasticsearch.action.bulk.BulkRequestBuilder;
+import org.elasticsearch.action.search.SearchRequestBuilder;
+import org.elasticsearch.action.search.TransportSearchAction;
+import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.CollectionUtils;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.mapper.OnScriptError;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.plugins.ScriptPlugin;
+import org.elasticsearch.script.LongFieldScript;
+import org.elasticsearch.script.ScriptContext;
+import org.elasticsearch.script.ScriptEngine;
+import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
+import org.elasticsearch.search.lookup.SearchLookup;
+import org.elasticsearch.tasks.TaskInfo;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.json.JsonXContent;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+
+import static org.elasticsearch.index.query.QueryBuilders.termQuery;
+import static org.elasticsearch.search.aggregations.AggregationBuilders.filters;
+import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.not;
+
+/**
+ * Verifies that searches running the {@code filters} aggregation finish once their tasks are cancelled.
+ * <p>
+ * The index maps a runtime field whose script blocks on {@code SCRIPT_SEMAPHORE}, so a search that evaluates
+ * that field stalls until the test releases permits. Each test starts such a search, waits for the script to
+ * block, cancels the search tasks, releases permits, and then expects the request and its tasks to go away.
+ */
+@ESIntegTestCase.SuiteScopeTestCase
+public class FiltersCancellationIT extends ESIntegTestCase {
+
+    private static final String INDEX = "idx";
+    private static final String PAUSE_FIELD = "pause";
+    private static final String NUMERIC_FIELD = "value";
+
+    private static final int NUM_DOCS = 100_000;
+    private static final int DOCS_PER_BULK = 100_000;
+    // NOTE(review): fewer permits than documents, presumably so that a search which ignores the cancellation
+    // cannot run to completion on the released permits alone - confirm
+    private static final int SEMAPHORE_PERMITS = NUM_DOCS - 1000;
+    // Shared with the script engine below; every script execution takes one permit
+    private static final Semaphore SCRIPT_SEMAPHORE = new Semaphore(0);
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins() {
+        return CollectionUtils.appendToCopy(super.nodePlugins(), pausableFieldPluginClass());
+    }
+
+    /**
+     * Plugin class providing the blocking script language; overridable so subclasses can swap it.
+     */
+    protected Class<? extends Plugin> pausableFieldPluginClass() {
+        return PauseScriptPlugin.class;
+    }
+
+    /**
+     * Creates the index with the two runtime fields, indexes {@code NUM_DOCS} documents and force-merges
+     * the index down to a single segment.
+     */
+    @Override
+    public void setupSuiteScopeCluster() throws Exception {
+        try (XContentBuilder mapping = JsonXContent.contentBuilder()) {
+            mapping.startObject();
+            mapping.startObject("runtime");
+            {
+                mapping.startObject(PAUSE_FIELD);
+                {
+                    mapping.field("type", "long");
+                    mapping.startObject("script").field("source", "").field("lang", PauseScriptPlugin.PAUSE_SCRIPT_LANG).endObject();
+                }
+                mapping.endObject();
+                mapping.startObject(NUMERIC_FIELD);
+                {
+                    mapping.field("type", "long");
+                }
+                mapping.endObject();
+            }
+            mapping.endObject();
+            mapping.endObject();
+
+            client().admin().indices().prepareCreate(INDEX).setMapping(mapping).get();
+        }
+
+        for (int bulkStart = 0; bulkStart < NUM_DOCS; bulkStart += DOCS_PER_BULK) {
+            BulkRequestBuilder bulk = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
+            // Clamp the last batch so the total never exceeds NUM_DOCS, whatever the bulk size is
+            int bulkEnd = Math.min(bulkStart + DOCS_PER_BULK, NUM_DOCS);
+            for (int docId = bulkStart; docId < bulkEnd; docId++) {
+                bulk.add(prepareIndex(INDEX).setId(Integer.toString(docId)).setSource(NUMERIC_FIELD, docId));
+            }
+            // A partially failed bulk would silently shrink the data set the tests rely on
+            var bulkResponse = bulk.get();
+            assertFalse(bulkResponse.buildFailureMessage(), bulkResponse.hasFailures());
+        }
+
+        client().admin().indices().prepareForceMerge(INDEX).setMaxNumSegments(1).get();
+    }
+
+    public void testFiltersCountCancellation() throws Exception {
+        ensureProperCancellation(client().prepareSearch(INDEX).addAggregation(filters("filters", pauseFieldFilters())));
+    }
+
+    public void testFiltersSubAggsCancellation() throws Exception {
+        ensureProperCancellation(
+            client().prepareSearch(INDEX)
+                .addAggregation(filters("filters", pauseFieldFilters()).subAggregation(terms("sub").field(PAUSE_FIELD)))
+        );
+    }
+
+    /**
+     * The two keyed filters used by every test; both query the blocking runtime field.
+     */
+    private static KeyedFilter[] pauseFieldFilters() {
+        return new KeyedFilter[] {
+            new KeyedFilter("filter1", termQuery(PAUSE_FIELD, 1)),
+            new KeyedFilter("filter2", termQuery(PAUSE_FIELD, 2)) };
+    }
+
+    /**
+     * Runs the search, waits until the pause script blocks, cancels the search tasks and then checks that
+     * the request completes and no search task is left.
+     */
+    private void ensureProperCancellation(SearchRequestBuilder searchRequestBuilder) throws Exception {
+        // Permits left over from a previous test would let the script run without blocking; start from zero
+        SCRIPT_SEMAPHORE.drainPermits();
+
+        var searchRequestFuture = searchRequestBuilder.setTimeout(TimeValue.timeValueSeconds(1)).execute();
+        assertFalse(searchRequestFuture.isCancelled());
+        assertFalse(searchRequestFuture.isDone());
+
+        // Check that there are search tasks running
+        assertThat(getSearchTasks(), not(empty()));
+
+        // Wait for the script field to get blocked
+        assertBusy(() -> assertThat(SCRIPT_SEMAPHORE.getQueueLength(), greaterThan(0)));
+
+        // Cancel the tasks
+        // Warning: Adding a waitForCompletion(true)/execute() here sometimes causes tasks to not get canceled and threads to get stuck
+        client().admin().cluster().prepareCancelTasks().setActions(TransportSearchAction.NAME + "*").get();
+
+        // Unblock the waiting script threads so they can reach the next cancellation check
+        SCRIPT_SEMAPHORE.release(SEMAPHORE_PERMITS);
+
+        // Ensure the search request finished and that there are no more search tasks
+        assertBusy(() -> {
+            assertTrue(searchRequestFuture.isDone());
+            assertThat(getSearchTasks(), empty());
+        });
+    }
+
+    /**
+     * Lists the currently running tasks whose action name starts with the search action name.
+     */
+    private List<TaskInfo> getSearchTasks() {
+        return client().admin()
+            .cluster()
+            .prepareListTasks()
+            .setActions(TransportSearchAction.NAME + "*")
+            .setDetailed(true)
+            .get()
+            .getTasks();
+    }
+
+    /**
+     * Script plugin registering the {@code pause} language: a long field script that takes one permit from
+     * {@code SCRIPT_SEMAPHORE} on every execution and then emits {@code 1}.
+     */
+    public static class PauseScriptPlugin extends Plugin implements ScriptPlugin {
+        public static final String PAUSE_SCRIPT_LANG = "pause";
+
+        @Override
+        public ScriptEngine getScriptEngine(Settings settings, Collection<ScriptContext<?>> contexts) {
+            return new ScriptEngine() {
+                @Override
+                public String getType() {
+                    return PAUSE_SCRIPT_LANG;
+                }
+
+                @Override
+                @SuppressWarnings("unchecked")
+                public <FactoryType> FactoryType compile(
+                    String name,
+                    String code,
+                    ScriptContext<FactoryType> context,
+                    Map<String, String> params
+                ) {
+                    if (context == LongFieldScript.CONTEXT) {
+                        return (FactoryType) new LongFieldScript.Factory() {
+                            @Override
+                            public LongFieldScript.LeafFactory newFactory(
+                                String fieldName,
+                                Map<String, Object> params,
+                                SearchLookup searchLookup,
+                                OnScriptError onScriptError
+                            ) {
+                                return ctx -> new LongFieldScript(fieldName, params, searchLookup, onScriptError, ctx) {
+                                    @Override
+                                    public void execute() {
+                                        try {
+                                            SCRIPT_SEMAPHORE.acquire();
+                                        } catch (InterruptedException e) {
+                                            // Restore the interrupt status for the owning thread before failing
+                                            Thread.currentThread().interrupt();
+                                            throw new AssertionError(e);
+                                        }
+                                        emit(1);
+                                    }
+                                };
+                            }
+                        };
+                    }
+                    throw new IllegalStateException("unsupported type " + context);
+                }
+
+                @Override
+                public Set<ScriptContext<?>> getSupportedContexts() {
+                    return Set.of(LongFieldScript.CONTEXT);
+                }
+            };
+        }
+    }
+}

+ 10 - 3
server/src/main/java/org/elasticsearch/search/aggregations/bucket/filter/FilterByFilterAggregator.java

@@ -23,6 +23,7 @@ import org.elasticsearch.search.aggregations.CardinalityUpperBound;
 import org.elasticsearch.search.aggregations.LeafBucketCollector;
 import org.elasticsearch.search.aggregations.support.AggregationContext;
 import org.elasticsearch.search.runtime.AbstractScriptFieldQuery;
+import org.elasticsearch.tasks.TaskCancelledException;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -268,7 +269,7 @@ public class FilterByFilterAggregator extends FiltersAggregator {
     private void collectCount(LeafReaderContext ctx, Bits live) throws IOException {
         Counter counter = new Counter(docCountProvider);
         for (int filterOrd = 0; filterOrd < filters().size(); filterOrd++) {
-            incrementBucketDocCount(filterOrd, filters().get(filterOrd).count(ctx, counter, live));
+            incrementBucketDocCount(filterOrd, filters().get(filterOrd).count(ctx, counter, live, this::checkCancelled));
         }
     }
 
@@ -306,11 +307,17 @@ public class FilterByFilterAggregator extends FiltersAggregator {
         MatchCollector collector = new MatchCollector();
         // create the buckets so we can call collectExistingBucket
         grow(filters().size() + 1);
-        filters().get(0).collect(aggCtx.getLeafReaderContext(), collector, live);
+        filters().get(0).collect(aggCtx.getLeafReaderContext(), collector, live, this::checkCancelled);
         for (int filterOrd = 1; filterOrd < filters().size(); filterOrd++) {
             collector.subCollector = collectableSubAggregators.getLeafCollector(aggCtx);
             collector.filterOrd = filterOrd;
-            filters().get(filterOrd).collect(aggCtx.getLeafReaderContext(), collector, live);
+            filters().get(filterOrd).collect(aggCtx.getLeafReaderContext(), collector, live, this::checkCancelled);
+        }
+    }
+
+    private void checkCancelled() { // handed as the cancellation callback to the filters' count/collect calls
+        if (context.isCancelled()) {
+            throw new TaskCancelledException("cancelled"); // surfaces the cancellation to whoever ran the callback
         }
     }
 

+ 7 - 4
server/src/main/java/org/elasticsearch/search/aggregations/bucket/filter/QueryToFilterAdapter.java

@@ -28,6 +28,7 @@ import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.Bits;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.search.aggregations.Aggregator;
+import org.elasticsearch.search.internal.CancellableBulkScorer;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
@@ -191,7 +192,7 @@ public class QueryToFilterAdapter {
     /**
      * Count the number of documents that match this filter in a leaf.
      */
-    long count(LeafReaderContext ctx, FiltersAggregator.Counter counter, Bits live) throws IOException {
+    long count(LeafReaderContext ctx, FiltersAggregator.Counter counter, Bits live, Runnable checkCancelled) throws IOException {
         /*
          * weight().count will return the count of matches for ctx if it can do
          * so in constant time, otherwise -1. The Weight is responsible for
@@ -215,20 +216,22 @@ public class QueryToFilterAdapter {
             // No hits in this segment.
             return 0;
         }
-        scorer.score(counter, live);
+        CancellableBulkScorer cancellableScorer = new CancellableBulkScorer(scorer, checkCancelled);
+        cancellableScorer.score(counter, live);
         return counter.readAndReset(ctx);
     }
 
     /**
      * Collect all documents that match this filter in this leaf.
      */
-    void collect(LeafReaderContext ctx, LeafCollector collector, Bits live) throws IOException {
+    void collect(LeafReaderContext ctx, LeafCollector collector, Bits live, Runnable checkCancelled) throws IOException {
         BulkScorer scorer = weight().bulkScorer(ctx);
         if (scorer == null) {
             // No hits in this segment.
             return;
         }
-        scorer.score(collector, live);
+        CancellableBulkScorer cancellableScorer = new CancellableBulkScorer(scorer, checkCancelled);
+        cancellableScorer.score(collector, live);
     }
 
     /**

+ 2 - 2
server/src/main/java/org/elasticsearch/search/internal/CancellableBulkScorer.java

@@ -20,7 +20,7 @@ import java.util.Objects;
  * A {@link BulkScorer} wrapper that runs a {@link Runnable} on a regular basis
  * so that the query can be interrupted.
  */
-final class CancellableBulkScorer extends BulkScorer {
+public final class CancellableBulkScorer extends BulkScorer {
 
     // we use the BooleanScorer window size as a base interval in order to make sure that we do not
     // slow down boolean queries
@@ -32,7 +32,7 @@ final class CancellableBulkScorer extends BulkScorer {
     private final BulkScorer scorer;
     private final Runnable checkCancelled;
 
-    CancellableBulkScorer(BulkScorer scorer, Runnable checkCancelled) {
+    public CancellableBulkScorer(BulkScorer scorer, Runnable checkCancelled) {
         this.scorer = Objects.requireNonNull(scorer);
         this.checkCancelled = Objects.requireNonNull(checkCancelled);
     }