浏览代码

Support concurrency in the AggregationPhase (#97828)

This commit makes the aggregation phase to generate a Collector Manager that supports concurrency.
Ignacio Vera 2 年之前
父节点
当前提交
9f39e33141

+ 6 - 2
server/src/main/java/org/elasticsearch/search/SearchService.java

@@ -1259,8 +1259,12 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
             );
             );
             context.addQuerySearchResultReleasable(aggContext);
             context.addQuerySearchResultReleasable(aggContext);
             try {
             try {
-                AggregatorFactories factories = source.aggregations().build(aggContext, null);
-                context.aggregations(new SearchContextAggregations(factories));
+                final AggregatorFactories factories = source.aggregations().build(aggContext, null);
+                final Supplier<AggregationReduceContext.Builder> supplier = () -> aggReduceContextBuilder(
+                    context::isCancelled,
+                    source.aggregations()
+                );
+                context.aggregations(new SearchContextAggregations(factories, supplier));
             } catch (IOException e) {
             } catch (IOException e) {
                 throw new AggregationInitializationException("Failed to create aggregators", e);
                 throw new AggregationInitializationException("Failed to create aggregators", e);
             }
             }

+ 68 - 31
server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java

@@ -8,15 +8,17 @@
 package org.elasticsearch.search.aggregations;
 package org.elasticsearch.search.aggregations;
 
 
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
 import org.elasticsearch.action.search.SearchShardTask;
 import org.elasticsearch.action.search.SearchShardTask;
 import org.elasticsearch.search.aggregations.support.TimeSeriesIndexSearcher;
 import org.elasticsearch.search.aggregations.support.TimeSeriesIndexSearcher;
 import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.query.QueryPhase;
 import org.elasticsearch.search.query.QueryPhase;
-import org.elasticsearch.search.query.SingleThreadCollectorManager;
 
 
 import java.io.IOException;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
 import java.util.List;
+import java.util.function.Supplier;
 
 
 /**
 /**
  * Aggregation phase of a search request, used to collect aggregations
  * Aggregation phase of a search request, used to collect aggregations
@@ -29,30 +31,50 @@ public class AggregationPhase {
         if (context.aggregations() == null) {
         if (context.aggregations() == null) {
             return;
             return;
         }
         }
-        BucketCollector bucketCollector;
+        final Supplier<Collector> collectorSupplier;
+        if (context.aggregations().isInSortOrderExecutionRequired()) {
+            executeInSortOrder(context, newBucketCollector(context));
+            collectorSupplier = () -> BucketCollector.NO_OP_COLLECTOR;
+        } else {
+            collectorSupplier = () -> newBucketCollector(context).asCollector();
+        }
+        context.aggregations().registerAggsCollectorManager(new CollectorManager<>() {
+            @Override
+            public Collector newCollector() {
+                return collectorSupplier.get();
+            }
+
+            @Override
+            public Void reduce(Collection<Collector> collectors) {
+                // we cannot run post-collection method here because we need to do it after the optional timeout
+                // has been removed from the index searcher. Therefore, we delay this processing to the
+                // AggregationPhase#execute method.
+                return null;
+            }
+        });
+    }
+
+    private static BucketCollector newBucketCollector(SearchContext context) {
         try {
         try {
-            context.aggregations().aggregators(context.aggregations().factories().createTopLevelAggregators());
-            bucketCollector = MultiBucketCollector.wrap(true, List.of(context.aggregations().aggregators()));
+            Aggregator[] aggregators = context.aggregations().factories().createTopLevelAggregators();
+            context.aggregations().aggregators(aggregators);
+            BucketCollector bucketCollector = MultiBucketCollector.wrap(true, List.of(aggregators));
             bucketCollector.preCollection();
             bucketCollector.preCollection();
+            return bucketCollector;
         } catch (IOException e) {
         } catch (IOException e) {
             throw new AggregationInitializationException("Could not initialize aggregators", e);
             throw new AggregationInitializationException("Could not initialize aggregators", e);
         }
         }
-        final Collector collector;
-        if (context.aggregations().factories().context() != null
-            && context.aggregations().factories().context().isInSortOrderExecutionRequired()) {
-            TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher(), getCancellationChecks(context));
-            searcher.setMinimumScore(context.minimumScore());
-            searcher.setProfiler(context);
-            try {
-                searcher.search(context.rewrittenQuery(), bucketCollector);
-            } catch (IOException e) {
-                throw new AggregationExecutionException("Could not perform time series aggregation", e);
-            }
-            collector = BucketCollector.NO_OP_COLLECTOR;
-        } else {
-            collector = bucketCollector.asCollector();
+    }
+
+    private static void executeInSortOrder(SearchContext context, BucketCollector collector) {
+        TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher(), getCancellationChecks(context));
+        searcher.setMinimumScore(context.minimumScore());
+        searcher.setProfiler(context);
+        try {
+            searcher.search(context.rewrittenQuery(), collector);
+        } catch (IOException e) {
+            throw new AggregationExecutionException("Could not perform time series aggregation", e);
         }
         }
-        context.aggregations().registerAggsCollectorManager(new SingleThreadCollectorManager(collector));
     }
     }
 
 
     private static List<Runnable> getCancellationChecks(SearchContext context) {
     private static List<Runnable> getCancellationChecks(SearchContext context) {
@@ -86,20 +108,35 @@ public class AggregationPhase {
             return;
             return;
         }
         }
 
 
-        Aggregator[] aggregators = context.aggregations().aggregators();
-
-        List<InternalAggregation> aggregations = new ArrayList<>(aggregators.length);
-        for (Aggregator aggregator : context.aggregations().aggregators()) {
-            try {
-                aggregator.postCollection();
-                aggregations.add(aggregator.buildTopLevel());
-            } catch (IOException e) {
-                throw new AggregationExecutionException("Failed to build aggregation [" + aggregator.name() + "]", e);
+        final List<InternalAggregations> internalAggregations = new ArrayList<>(context.aggregations().aggregators().size());
+        for (Aggregator[] aggregators : context.aggregations().aggregators()) {
+            final List<InternalAggregation> aggregations = new ArrayList<>(aggregators.length);
+            for (Aggregator aggregator : aggregators) {
+                try {
+                    aggregator.postCollection();
+                    aggregations.add(aggregator.buildTopLevel());
+                } catch (IOException e) {
+                    throw new AggregationExecutionException("Failed to build aggregation [" + aggregator.name() + "]", e);
+                }
+                // release the aggregator to claim the used bytes as we don't need it anymore
+                aggregator.releaseAggregations();
             }
             }
-            // release the aggregator to claim the used bytes as we don't need it anymore
-            aggregator.releaseAggregations();
+            internalAggregations.add(InternalAggregations.from(aggregations));
+        }
+
+        if (internalAggregations.size() > 1) {
+            // we execute this search using more than one slice. In order to keep memory requirements
+            // low, we do a partial reduction here.
+            context.queryResult()
+                .aggregations(
+                    InternalAggregations.topLevelReduce(
+                        internalAggregations,
+                        context.aggregations().getAggregationReduceContextBuilder().forPartialReduction()
+                    )
+                );
+        } else {
+            context.queryResult().aggregations(internalAggregations.get(0));
         }
         }
-        context.queryResult().aggregations(InternalAggregations.from(aggregations));
 
 
         // disable aggregations so that they don't run on next pages in case of scrolling
         // disable aggregations so that they don't run on next pages in case of scrolling
         context.aggregations(null);
         context.aggregations(null);

+ 28 - 4
server/src/main/java/org/elasticsearch/search/aggregations/SearchContextAggregations.java

@@ -10,27 +10,37 @@ package org.elasticsearch.search.aggregations;
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.CollectorManager;
 import org.apache.lucene.search.CollectorManager;
 
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Supplier;
+
 /**
 /**
  * The aggregation context that is part of the search context.
  * The aggregation context that is part of the search context.
  */
  */
 public class SearchContextAggregations {
 public class SearchContextAggregations {
 
 
     private final AggregatorFactories factories;
     private final AggregatorFactories factories;
-    private Aggregator[] aggregators;
+    private final Supplier<AggregationReduceContext.Builder> toAggregationReduceContextBuilder;
+    private final List<Aggregator[]> aggregators;
     private CollectorManager<Collector, Void> aggCollectorManager;
     private CollectorManager<Collector, Void> aggCollectorManager;
 
 
     /**
     /**
      * Creates a new aggregation context with the parsed aggregator factories
      * Creates a new aggregation context with the parsed aggregator factories
      */
      */
-    public SearchContextAggregations(AggregatorFactories factories) {
+    public SearchContextAggregations(
+        AggregatorFactories factories,
+        Supplier<AggregationReduceContext.Builder> toAggregationReduceContextBuilder
+    ) {
         this.factories = factories;
         this.factories = factories;
+        this.toAggregationReduceContextBuilder = toAggregationReduceContextBuilder;
+        this.aggregators = new ArrayList<>();
     }
     }
 
 
     public AggregatorFactories factories() {
     public AggregatorFactories factories() {
         return factories;
         return factories;
     }
     }
 
 
-    public Aggregator[] aggregators() {
+    public List<Aggregator[]> aggregators() {
         return aggregators;
         return aggregators;
     }
     }
 
 
@@ -40,7 +50,7 @@ public class SearchContextAggregations {
      * @param aggregators The top level aggregators of the search execution.
      * @param aggregators The top level aggregators of the search execution.
      */
      */
     public void aggregators(Aggregator[] aggregators) {
     public void aggregators(Aggregator[] aggregators) {
-        this.aggregators = aggregators;
+        this.aggregators.add(aggregators);
     }
     }
 
 
     /**
     /**
@@ -56,4 +66,18 @@ public class SearchContextAggregations {
     public CollectorManager<Collector, Void> getAggsCollectorManager() {
     public CollectorManager<Collector, Void> getAggsCollectorManager() {
         return aggCollectorManager;
         return aggCollectorManager;
     }
     }
+
+    /**
+     * Returns if the aggregations needs to execute in sort order.
+     */
+    public boolean isInSortOrderExecutionRequired() {
+        return factories.context() != null && factories.context().isInSortOrderExecutionRequired();
+    }
+
+    /**
+     * Returns a builder for the reduce context.
+     */
+    public AggregationReduceContext.Builder getAggregationReduceContextBuilder() {
+        return toAggregationReduceContextBuilder.get();
+    }
 }
 }

+ 1 - 1
server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java

@@ -1092,7 +1092,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
 
 
         executed.clear();
         executed.clear();
         context.trackTotalHitsUpTo(SearchContext.TRACK_TOTAL_HITS_DISABLED);
         context.trackTotalHitsUpTo(SearchContext.TRACK_TOTAL_HITS_DISABLED);
-        context.aggregations(new SearchContextAggregations(AggregatorFactories.EMPTY));
+        context.aggregations(new SearchContextAggregations(AggregatorFactories.EMPTY, () -> null));
         QueryPhase.executeRank(context);
         QueryPhase.executeRank(context);
         assertEquals(context.rewrittenQuery(), executed.get(0));
         assertEquals(context.rewrittenQuery(), executed.get(0));
         assertEquals(queries, executed.subList(1, executed.size()));
         assertEquals(queries, executed.subList(1, executed.size()));