浏览代码

Introduce TwoPhaseCollector to leverage ContextIndexSearcher in AggregatorTestCase (#98835)

Ignacio Vera 2 年之前
父节点
当前提交
2cf07f603b

+ 7 - 1
server/src/main/java/org/elasticsearch/search/aggregations/BucketCollector.java

@@ -12,6 +12,7 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.ScoreMode;
+import org.elasticsearch.search.internal.TwoPhaseCollector;
 
 import java.io.IOException;
 
@@ -79,7 +80,7 @@ public abstract class BucketCollector {
         return new BucketCollectorWrapper(this);
     }
 
-    public record BucketCollectorWrapper(BucketCollector bucketCollector) implements Collector {
+    public record BucketCollectorWrapper(BucketCollector bucketCollector) implements TwoPhaseCollector {
 
         @Override
         public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
@@ -90,5 +91,10 @@ public abstract class BucketCollector {
         public ScoreMode scoreMode() {
             return bucketCollector.scoreMode();
         }
+
+        @Override
+        public void doPostCollection() throws IOException {
+            bucketCollector.postCollection();
+        }
     }
 }

+ 2 - 6
server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java

@@ -42,12 +42,10 @@ import org.elasticsearch.core.Releasable;
 import org.elasticsearch.lucene.util.CombinedBitSet;
 import org.elasticsearch.search.dfs.AggregatedDfs;
 import org.elasticsearch.search.profile.Timer;
-import org.elasticsearch.search.profile.query.InternalProfileCollector;
 import org.elasticsearch.search.profile.query.ProfileWeight;
 import org.elasticsearch.search.profile.query.QueryProfileBreakdown;
 import org.elasticsearch.search.profile.query.QueryProfiler;
 import org.elasticsearch.search.profile.query.QueryTimingType;
-import org.elasticsearch.search.query.QueryPhaseCollector;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -489,10 +487,8 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
     }
 
     private void doAggregationPostCollection(Collector collector) throws IOException {
-        if (collector instanceof QueryPhaseCollector queryPhaseCollector) {
-            queryPhaseCollector.doPostCollection();
-        } else if (collector instanceof InternalProfileCollector profilerCollector) {
-            profilerCollector.doPostCollection();
+        if (collector instanceof TwoPhaseCollector twoPhaseCollector) {
+            twoPhaseCollector.doPostCollection();
         }
     }
 

+ 23 - 0
server/src/main/java/org/elasticsearch/search/internal/TwoPhaseCollector.java

@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.internal;
+
+import org.apache.lucene.search.Collector;
+
+import java.io.IOException;
+
+/** A {@link Collector} extension that allows running a post-collection phase. This phase
+ * runs on the same thread as the collection phase. */
+public interface TwoPhaseCollector extends Collector {
+
+    /**
+     * Runs the post-collection phase.
+     */
+    void doPostCollection() throws IOException;
+}

+ 5 - 9
server/src/main/java/org/elasticsearch/search/profile/query/InternalProfileCollector.java

@@ -10,8 +10,7 @@ package org.elasticsearch.search.profile.query;
 
 import org.apache.lucene.sandbox.search.ProfilerCollector;
 import org.apache.lucene.search.Collector;
-import org.elasticsearch.search.aggregations.BucketCollector;
-import org.elasticsearch.search.query.QueryPhaseCollector;
+import org.elasticsearch.search.internal.TwoPhaseCollector;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -27,7 +26,7 @@ import java.util.List;
  * <p>
  * InternalProfiler facilitates the linking of the Collector graph
  */
-public class InternalProfileCollector extends ProfilerCollector {
+public class InternalProfileCollector extends ProfilerCollector implements TwoPhaseCollector {
 
     private final InternalProfileCollector[] children;
     private final Collector wrappedCollector;
@@ -78,13 +77,10 @@ public class InternalProfileCollector extends ProfilerCollector {
         return new CollectorResult(getName(), getReason(), getTime(), childResults);
     }
 
+    @Override
     public void doPostCollection() throws IOException {
-        if (wrappedCollector instanceof InternalProfileCollector profileCollector) {
-            profileCollector.doPostCollection();
-        } else if (wrappedCollector instanceof QueryPhaseCollector queryPhaseCollector) {
-            queryPhaseCollector.doPostCollection();
-        } else if (wrappedCollector instanceof BucketCollector.BucketCollectorWrapper aggsCollector) {
-            aggsCollector.bucketCollector().postCollection();
+        if (wrappedCollector instanceof TwoPhaseCollector twoPhaseCollector) {
+            twoPhaseCollector.doPostCollection();
         }
     }
 }

+ 5 - 7
server/src/main/java/org/elasticsearch/search/query/QueryPhaseCollector.java

@@ -22,8 +22,7 @@ import org.apache.lucene.search.ScorerSupplier;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.Bits;
 import org.elasticsearch.common.lucene.Lucene;
-import org.elasticsearch.search.aggregations.BucketCollector;
-import org.elasticsearch.search.profile.query.InternalProfileCollector;
+import org.elasticsearch.search.internal.TwoPhaseCollector;
 
 import java.io.IOException;
 import java.util.Objects;
@@ -40,7 +39,7 @@ import java.util.concurrent.atomic.AtomicInteger;
  * When top docs as well as aggs are collected (because both collectors were provided), skipping low scoring hits via
  * {@link Scorable#setMinCompetitiveScore(float)} is not supported for either of the collectors.
  */
-public final class QueryPhaseCollector implements Collector {
+public final class QueryPhaseCollector implements TwoPhaseCollector {
     private final Collector aggsCollector;
     private final Collector topDocsCollector;
     private final TerminateAfterChecker terminateAfterChecker;
@@ -374,11 +373,10 @@ public final class QueryPhaseCollector implements Collector {
         }
     };
 
+    @Override
     public void doPostCollection() throws IOException {
-        if (aggsCollector instanceof BucketCollector.BucketCollectorWrapper bucketCollectorWrapper) {
-            bucketCollectorWrapper.bucketCollector().postCollection();
-        } else if (aggsCollector instanceof InternalProfileCollector profileCollector) {
-            profileCollector.doPostCollection();
+        if (aggsCollector instanceof TwoPhaseCollector twoPhaseCollector) {
+            twoPhaseCollector.doPostCollection();
         }
     }
 }

+ 11 - 18
test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java

@@ -37,7 +37,6 @@ import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.analysis.MockAnalyzer;
 import org.apache.lucene.tests.index.AssertingDirectoryReader;
 import org.apache.lucene.tests.index.RandomIndexWriter;
-import org.apache.lucene.tests.search.AssertingIndexSearcher;
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.util.Accountable;
 import org.apache.lucene.util.BytesRef;
@@ -569,6 +568,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
                     root.preCollection();
                     aggregators.add(root);
                     new TimeSeriesIndexSearcher(searcher, List.of()).search(rewritten, MultiBucketCollector.wrap(true, List.of(root)));
+                    root.postCollection();
                 } else {
                     CollectorManager<Collector, Void> collectorManager = new CollectorManager<>() {
                         @Override
@@ -591,7 +591,6 @@ public abstract class AggregatorTestCase extends ESTestCase {
                     }
                 }
                 for (C agg : aggregators) {
-                    agg.postCollection();
                     internalAggs.add(agg.buildTopLevel());
                 }
             } finally {
@@ -776,7 +775,6 @@ public abstract class AggregatorTestCase extends ESTestCase {
             Aggregator aggregator = createAggregator(builder, context);
             aggregator.preCollection();
             searcher.search(context.query(), aggregator.asCollector());
-            aggregator.postCollection();
             InternalAggregation r = aggregator.buildTopLevel();
             r = r.reduce(
                 List.of(r),
@@ -909,21 +907,16 @@ public abstract class AggregatorTestCase extends ESTestCase {
      * sets the IndexSearcher to run on concurrent mode.
      */
     protected IndexSearcher newIndexSearcher(DirectoryReader indexReader) throws IOException {
-        if (randomBoolean()) {
-            // this executes basic query checks and asserts that weights are normalized only once etc.
-            return new AssertingIndexSearcher(random(), indexReader);
-        } else {
-            return new ContextIndexSearcher(
-                indexReader,
-                IndexSearcher.getDefaultSimilarity(),
-                IndexSearcher.getDefaultQueryCache(),
-                IndexSearcher.getDefaultQueryCachingPolicy(),
-                randomBoolean(),
-                this.threadPoolExecutor,
-                this.threadPoolExecutor.getMaximumPoolSize(),
-                1 // forces multiple slices
-            );
-        }
+        return new ContextIndexSearcher(
+            indexReader,
+            IndexSearcher.getDefaultSimilarity(),
+            IndexSearcher.getDefaultQueryCache(),
+            IndexSearcher.getDefaultQueryCachingPolicy(),
+            randomBoolean(),
+            this.threadPoolExecutor,
+            this.threadPoolExecutor.getMaximumPoolSize(),
+            1 // forces multiple slices
+        );
     }
 
     /**

+ 5 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetsAggregationBuilder.java

@@ -181,6 +181,11 @@ public final class FrequentItemSetsAggregationBuilder extends AbstractAggregatio
         return true;
     }
 
+    @Override
+    public boolean supportsParallelCollection() {
+        return false;
+    }
+
     @Override
     protected AggregationBuilder shallowCopy(Builder factoriesBuilder, Map<String, Object> metadata) {
         return new FrequentItemSetsAggregationBuilder(name, fields, minimumSupport, minimumSetSize, size, filter, executionHint);