Selaa lähdekoodia

Stop using round-tripped PipelineAggregators (#53423)

This begins to clean up how `PipelineAggregator`s and executed.
Previously, we would create the `PipelineAggregator`s on the data nodes
and embed them in the aggregation tree. When it came time to execute the
pipeline aggregation we'd use the `PipelineAggregator`s that were on the
first shard's results. This is inefficient because:
1. The data node needs to make the `PipelineAggregator` only to
   serialize it and then throw it away.
2. The coordinating node needs to deserialize all of the
   `PipelineAggregator`s even though it only needs one of them.
3. You end up with many `PipelineAggregator` instances when you only
   really *need* one per pipeline.
4. `PipelineAggregator` needs to implement serialization.

This begins to undo these by building the `PipelineAggregator`s directly
on the coordinating node and using those instead of the
`PipelineAggregator`s in the aggregtion tree. In a follow up change
we'll stop serializing the `PipelineAggregator`s to node versions that
support this behavior. And, one day, we'll be able to remove
`PipelineAggregator` from the aggregation result tree entirely.

Importantly, this doesn't change how pipeline aggregations are declared
or parsed or requested. They are still part of the `AggregationBuilder`
tree because *that* makes sense.
Nik Everett 5 vuotta sitten
vanhempi
commit
4d81edb625
35 muutettua tiedostoa jossa 458 lisäystä ja 254 poistoa
  1. 3 2
      modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStatsTests.java
  2. 35 20
      server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java
  3. 7 8
      server/src/main/java/org/elasticsearch/action/search/SearchResponseMerger.java
  4. 9 7
      server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java
  5. 1 1
      server/src/main/java/org/elasticsearch/node/Node.java
  6. 29 4
      server/src/main/java/org/elasticsearch/search/SearchService.java
  7. 10 0
      server/src/main/java/org/elasticsearch/search/aggregations/AggregationBuilder.java
  8. 25 17
      server/src/main/java/org/elasticsearch/search/aggregations/AggregatorFactories.java
  9. 46 10
      server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregation.java
  10. 4 5
      server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java
  11. 10 8
      server/src/main/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregation.java
  12. 8 6
      server/src/main/java/org/elasticsearch/search/aggregations/bucket/InternalSingleBucketAggregation.java
  13. 44 0
      server/src/main/java/org/elasticsearch/search/aggregations/pipeline/PipelineAggregator.java
  14. 7 11
      server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java
  15. 7 14
      server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java
  16. 38 20
      server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java
  17. 16 13
      server/src/test/java/org/elasticsearch/action/search/SearchResponseMergerTests.java
  18. 32 18
      server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java
  19. 4 3
      server/src/test/java/org/elasticsearch/search/SearchServiceTests.java
  20. 14 0
      server/src/test/java/org/elasticsearch/search/aggregations/AggregatorFactoriesTests.java
  21. 1 1
      server/src/test/java/org/elasticsearch/search/aggregations/BasePipelineAggregationTestCase.java
  22. 14 13
      server/src/test/java/org/elasticsearch/search/aggregations/InternalAggregationsTests.java
  23. 2 7
      server/src/test/java/org/elasticsearch/search/aggregations/bucket/composite/InternalCompositeTests.java
  24. 3 2
      server/src/test/java/org/elasticsearch/search/aggregations/bucket/histogram/InternalHistogramTests.java
  25. 2 1
      server/src/test/java/org/elasticsearch/search/aggregations/bucket/significant/SignificanceHeuristicTests.java
  26. 4 3
      server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorTests.java
  27. 1 1
      server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java
  28. 10 14
      test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java
  29. 24 7
      test/framework/src/main/java/org/elasticsearch/test/InternalAggregationTestCase.java
  30. 7 7
      x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java
  31. 7 7
      x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java
  32. 7 4
      x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java
  33. 7 4
      x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/RollupResponseTranslator.java
  34. 1 2
      x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/action/TransportRollupSearchAction.java
  35. 19 14
      x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/RollupResponseTranslationTests.java

+ 3 - 2
modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStatsTests.java

@@ -32,6 +32,7 @@ import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.ParsedAggregation;
 import org.elasticsearch.search.aggregations.matrix.stats.InternalMatrixStats.Fields;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.test.InternalAggregationTestCase;
 
 import java.io.IOException;
@@ -162,8 +163,8 @@ public class InternalMatrixStatsTests extends InternalAggregationTestCase<Intern
 
         ScriptService mockScriptService = mockScriptService();
         MockBigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
-        InternalAggregation.ReduceContext context =
-            new InternalAggregation.ReduceContext(bigArrays, mockScriptService, true);
+        InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forFinalReduction(
+                bigArrays, mockScriptService, b -> {}, PipelineTree.EMPTY);
         InternalMatrixStats reduced = (InternalMatrixStats) shardResults.get(0).reduce(shardResults, context);
         multiPassStats.assertNearlyEqual(reduced.getResults());
     }

+ 35 - 20
server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

@@ -21,6 +21,7 @@ package org.elasticsearch.action.search;
 
 import com.carrotsearch.hppc.IntArrayList;
 import com.carrotsearch.hppc.ObjectObjectHashMap;
+
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.CollectionStatistics;
 import org.apache.lucene.search.FieldDoc;
@@ -69,17 +70,13 @@ import java.util.function.IntFunction;
 import java.util.stream.Collectors;
 
 public final class SearchPhaseController {
-
     private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0];
 
-    private final Function<Boolean, ReduceContext> reduceContextFunction;
+    private final Function<SearchRequest, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder;
 
-    /**
-     * Constructor.
-     * @param reduceContextFunction A function that builds a context for the reduce of an {@link InternalAggregation}
-     */
-    public SearchPhaseController(Function<Boolean, ReduceContext> reduceContextFunction) {
-        this.reduceContextFunction = reduceContextFunction;
+    public SearchPhaseController(
+            Function<SearchRequest, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder) {
+        this.requestToAggReduceContextBuilder = requestToAggReduceContextBuilder;
     }
 
     public AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
@@ -394,7 +391,18 @@ public final class SearchPhaseController {
      * @param queryResults a list of non-null query shard results
      */
     ReducedQueryPhase reducedScrollQueryPhase(Collection<? extends SearchPhaseResult> queryResults) {
-        return reducedQueryPhase(queryResults, true, SearchContext.TRACK_TOTAL_HITS_ACCURATE, true);
+        InternalAggregation.ReduceContextBuilder aggReduceContextBuilder = new InternalAggregation.ReduceContextBuilder() {
+            @Override
+            public ReduceContext forPartialReduction() {
+                throw new UnsupportedOperationException("Scroll requests don't have aggs");
+            }
+
+            @Override
+            public ReduceContext forFinalReduction() {
+                throw new UnsupportedOperationException("Scroll requests don't have aggs");
+            }
+        };
+        return reducedQueryPhase(queryResults, true, SearchContext.TRACK_TOTAL_HITS_ACCURATE, aggReduceContextBuilder, true);
     }
 
     /**
@@ -402,9 +410,11 @@ public final class SearchPhaseController {
      * @param queryResults a list of non-null query shard results
      */
     public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
-                                               boolean isScrollRequest, int trackTotalHitsUpTo, boolean performFinalReduce) {
+                                               boolean isScrollRequest, int trackTotalHitsUpTo,
+                                               InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
+                                               boolean performFinalReduce) {
         return reducedQueryPhase(queryResults, null, new ArrayList<>(), new TopDocsStats(trackTotalHitsUpTo),
-            0, isScrollRequest, performFinalReduce);
+            0, isScrollRequest, aggReduceContextBuilder, performFinalReduce);
     }
 
     /**
@@ -421,6 +431,7 @@ public final class SearchPhaseController {
     private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
                                                 List<InternalAggregations> bufferedAggs, List<TopDocs> bufferedTopDocs,
                                                 TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest,
+                                                InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
                                                 boolean performFinalReduce) {
         assert numReducePhases >= 0 : "num reduce phases must be >= 0 but was: " + numReducePhases;
         numReducePhases++; // increment for this phase
@@ -496,9 +507,8 @@ public final class SearchPhaseController {
             reducedSuggest = new Suggest(Suggest.reduce(groupedSuggestions));
             reducedCompletionSuggestions = reducedSuggest.filter(CompletionSuggestion.class);
         }
-        ReduceContext reduceContext = reduceContextFunction.apply(performFinalReduce);
-        final InternalAggregations aggregations = aggregationsList.isEmpty() ? null :
-            InternalAggregations.topLevelReduce(aggregationsList, reduceContext);
+        final InternalAggregations aggregations = aggregationsList.isEmpty() ? null : InternalAggregations.topLevelReduce(aggregationsList,
+                    performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction());
         final SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
         final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, queryResults, bufferedTopDocs, topDocsStats, from, size,
             reducedCompletionSuggestions);
@@ -600,6 +610,7 @@ public final class SearchPhaseController {
         private int numReducePhases = 0;
         private final TopDocsStats topDocsStats;
         private final int topNSize;
+        private final InternalAggregation.ReduceContextBuilder aggReduceContextBuilder;
         private final boolean performFinalReduce;
 
         /**
@@ -613,7 +624,9 @@ public final class SearchPhaseController {
          */
         private QueryPhaseResultConsumer(SearchProgressListener progressListener, SearchPhaseController controller,
                                          int expectedResultSize, int bufferSize, boolean hasTopDocs, boolean hasAggs,
-                                         int trackTotalHitsUpTo, int topNSize, boolean performFinalReduce) {
+                                         int trackTotalHitsUpTo, int topNSize,
+                                         InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
+                                         boolean performFinalReduce) {
             super(expectedResultSize);
             if (expectedResultSize != 1 && bufferSize < 2) {
                 throw new IllegalArgumentException("buffer size must be >= 2 if there is more than one expected result");
@@ -635,6 +648,7 @@ public final class SearchPhaseController {
             this.bufferSize = bufferSize;
             this.topDocsStats = new TopDocsStats(trackTotalHitsUpTo);
             this.topNSize = topNSize;
+            this.aggReduceContextBuilder = aggReduceContextBuilder;
             this.performFinalReduce = performFinalReduce;
         }
 
@@ -650,7 +664,7 @@ public final class SearchPhaseController {
             if (querySearchResult.isNull() == false) {
                 if (index == bufferSize) {
                     if (hasAggs) {
-                        ReduceContext reduceContext = controller.reduceContextFunction.apply(false);
+                        ReduceContext reduceContext = aggReduceContextBuilder.forPartialReduction();
                         InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(Arrays.asList(aggsBuffer), reduceContext);
                         Arrays.fill(aggsBuffer, null);
                         aggsBuffer[0] = reducedAggs;
@@ -693,8 +707,8 @@ public final class SearchPhaseController {
 
         @Override
         public ReducedQueryPhase reduce() {
-            ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(),
-                getRemainingAggs(), getRemainingTopDocs(), topDocsStats, numReducePhases, false, performFinalReduce);
+            ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(), getRemainingAggs(), getRemainingTopDocs(),
+                    topDocsStats, numReducePhases, false, aggReduceContextBuilder, performFinalReduce);
             progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()),
                 reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
             return reducePhase;
@@ -730,13 +744,14 @@ public final class SearchPhaseController {
         final boolean hasAggs = source != null && source.aggregations() != null;
         final boolean hasTopDocs = source == null || source.size() != 0;
         final int trackTotalHitsUpTo = resolveTrackTotalHits(request);
+        InternalAggregation.ReduceContextBuilder aggReduceContextBuilder = requestToAggReduceContextBuilder.apply(request);
         if (isScrollRequest == false && (hasAggs || hasTopDocs)) {
             // no incremental reduce if scroll is used - we only hit a single shard or sometimes more...
             if (request.getBatchedReduceSize() < numShards) {
                 int topNSize = getTopDocsSize(request);
                 // only use this if there are aggs and if there are more shards than we should reduce at once
                 return new QueryPhaseResultConsumer(listener, this, numShards, request.getBatchedReduceSize(), hasTopDocs, hasAggs,
-                    trackTotalHitsUpTo, topNSize, request.isFinalReduce());
+                    trackTotalHitsUpTo, topNSize, aggReduceContextBuilder, request.isFinalReduce());
             }
         }
         return new ArraySearchPhaseResults<SearchPhaseResult>(numShards) {
@@ -750,7 +765,7 @@ public final class SearchPhaseController {
             ReducedQueryPhase reduce() {
                 List<SearchPhaseResult> resultList = results.asList();
                 final ReducedQueryPhase reducePhase =
-                    reducedQueryPhase(resultList, isScrollRequest, trackTotalHitsUpTo, request.isFinalReduce());
+                    reducedQueryPhase(resultList, isScrollRequest, trackTotalHitsUpTo, aggReduceContextBuilder, request.isFinalReduce());
                 listener.notifyFinalReduce(SearchProgressListener.buildSearchShards(resultList),
                     reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
                 return reducePhase;

+ 7 - 8
server/src/main/java/org/elasticsearch/action/search/SearchResponseMerger.java

@@ -27,13 +27,15 @@ import org.apache.lucene.search.TopFieldDocs;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.search.SearchPhaseController.TopDocsStats;
+import org.elasticsearch.action.search.SearchResponse.Clusters;
 import org.elasticsearch.action.search.TransportSearchAction.SearchTimeProvider;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchShardTarget;
-import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext;
+import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.internal.InternalSearchResponse;
 import org.elasticsearch.search.profile.ProfileShardResult;
@@ -51,11 +53,8 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.TreeMap;
 import java.util.concurrent.CopyOnWriteArrayList;
-import java.util.function.Function;
 
-import static org.elasticsearch.action.search.SearchPhaseController.TopDocsStats;
 import static org.elasticsearch.action.search.SearchPhaseController.mergeTopDocs;
-import static org.elasticsearch.action.search.SearchResponse.Clusters;
 
 /**
  * Merges multiple search responses into one. Used in cross-cluster search when reduction is performed locally on each cluster.
@@ -81,16 +80,16 @@ final class SearchResponseMerger {
     final int size;
     final int trackTotalHitsUpTo;
     private final SearchTimeProvider searchTimeProvider;
-    private final Function<Boolean, ReduceContext> reduceContextFunction;
+    private final InternalAggregation.ReduceContextBuilder aggReduceContextBuilder;
     private final List<SearchResponse> searchResponses = new CopyOnWriteArrayList<>();
 
     SearchResponseMerger(int from, int size, int trackTotalHitsUpTo, SearchTimeProvider searchTimeProvider,
-                         Function<Boolean, ReduceContext> reduceContextFunction) {
+                         InternalAggregation.ReduceContextBuilder aggReduceContextBuilder) {
         this.from = from;
         this.size = size;
         this.trackTotalHitsUpTo = trackTotalHitsUpTo;
         this.searchTimeProvider = Objects.requireNonNull(searchTimeProvider);
-        this.reduceContextFunction = Objects.requireNonNull(reduceContextFunction);
+        this.aggReduceContextBuilder = Objects.requireNonNull(aggReduceContextBuilder);
     }
 
     /**
@@ -196,7 +195,7 @@ final class SearchResponseMerger {
         SearchHits mergedSearchHits = topDocsToSearchHits(topDocs, topDocsStats);
         setSuggestShardIndex(shards, groupedSuggestions);
         Suggest suggest = groupedSuggestions.isEmpty() ? null : new Suggest(Suggest.reduce(groupedSuggestions));
-        InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(aggs, reduceContextFunction.apply(true));
+        InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(aggs, aggReduceContextBuilder.forFinalReduction());
         ShardSearchFailure[] shardFailures = failures.toArray(ShardSearchFailure.EMPTY_ARRAY);
         SearchProfileShardResults profileShardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
         //make failures ordering consistent between ordinary search and CCS by looking at the shard they come from

+ 9 - 7
server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java

@@ -212,9 +212,10 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                 executeLocalSearch(task, timeProvider, searchRequest, localIndices, clusterState, listener);
             } else {
                 if (shouldMinimizeRoundtrips(searchRequest)) {
-                    ccsRemoteReduce(searchRequest, localIndices, remoteClusterIndices, timeProvider, searchService::createReduceContext,
-                        remoteClusterService, threadPool, listener,
-                        (r, l) -> executeLocalSearch(task, timeProvider, r, localIndices, clusterState, l));
+                    ccsRemoteReduce(searchRequest, localIndices, remoteClusterIndices, timeProvider,
+                            searchService.aggReduceContextBuilder(searchRequest),
+                            remoteClusterService, threadPool, listener,
+                            (r, l) -> executeLocalSearch(task, timeProvider, r, localIndices, clusterState, l));
                 } else {
                     AtomicInteger skippedClusters = new AtomicInteger(0);
                     collectSearchShards(searchRequest.indicesOptions(), searchRequest.preference(), searchRequest.routing(),
@@ -260,7 +261,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
     }
 
     static void ccsRemoteReduce(SearchRequest searchRequest, OriginalIndices localIndices, Map<String, OriginalIndices> remoteIndices,
-                                SearchTimeProvider timeProvider, Function<Boolean, InternalAggregation.ReduceContext> reduceContext,
+                                SearchTimeProvider timeProvider, InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
                                 RemoteClusterService remoteClusterService, ThreadPool threadPool, ActionListener<SearchResponse> listener,
                                 BiConsumer<SearchRequest, ActionListener<SearchResponse>> localSearchConsumer) {
 
@@ -298,7 +299,8 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                 }
             });
         } else {
-            SearchResponseMerger searchResponseMerger = createSearchResponseMerger(searchRequest.source(), timeProvider, reduceContext);
+            SearchResponseMerger searchResponseMerger = createSearchResponseMerger(
+                    searchRequest.source(), timeProvider, aggReduceContextBuilder);
             AtomicInteger skippedClusters = new AtomicInteger(0);
             final AtomicReference<Exception> exceptions = new AtomicReference<>();
             int totalClusters = remoteIndices.size() + (localIndices == null ? 0 : 1);
@@ -325,7 +327,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
     }
 
     static SearchResponseMerger createSearchResponseMerger(SearchSourceBuilder source, SearchTimeProvider timeProvider,
-                                                           Function<Boolean, InternalAggregation.ReduceContext> reduceContextFunction) {
+                                                           InternalAggregation.ReduceContextBuilder aggReduceContextBuilder) {
         final int from;
         final int size;
         final int trackTotalHitsUpTo;
@@ -342,7 +344,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
             source.from(0);
             source.size(from + size);
         }
-        return new SearchResponseMerger(from, size, trackTotalHitsUpTo, timeProvider, reduceContextFunction);
+        return new SearchResponseMerger(from, size, trackTotalHitsUpTo, timeProvider, aggReduceContextBuilder);
     }
 
     static void collectSearchShards(IndicesOptions indicesOptions, String preference, String routing, AtomicInteger skippedClusters,

+ 1 - 1
server/src/main/java/org/elasticsearch/node/Node.java

@@ -575,7 +575,7 @@ public class Node implements Closeable {
                     b.bind(MetaDataCreateIndexService.class).toInstance(metaDataCreateIndexService);
                     b.bind(SearchService.class).toInstance(searchService);
                     b.bind(SearchTransportService.class).toInstance(searchTransportService);
-                    b.bind(SearchPhaseController.class).toInstance(new SearchPhaseController(searchService::createReduceContext));
+                    b.bind(SearchPhaseController.class).toInstance(new SearchPhaseController(searchService::aggReduceContextBuilder));
                     b.bind(Transport.class).toInstance(transport);
                     b.bind(TransportService.class).toInstance(transportService);
                     b.bind(NetworkService.class).toInstance(networkService);

+ 29 - 4
server/src/main/java/org/elasticsearch/search/SearchService.java

@@ -28,6 +28,7 @@ import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRunnable;
 import org.elasticsearch.action.OriginalIndices;
+import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchShardTask;
 import org.elasticsearch.action.search.SearchType;
 import org.elasticsearch.cluster.ClusterState;
@@ -69,8 +70,10 @@ import org.elasticsearch.script.ScriptService;
 import org.elasticsearch.search.aggregations.AggregationInitializationException;
 import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.aggregations.InternalAggregation;
+import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext;
 import org.elasticsearch.search.aggregations.MultiBucketConsumerService;
 import org.elasticsearch.search.aggregations.SearchContextAggregations;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.collapse.CollapseContext;
 import org.elasticsearch.search.dfs.DfsPhase;
@@ -84,11 +87,11 @@ import org.elasticsearch.search.fetch.subphase.FetchDocValuesContext;
 import org.elasticsearch.search.fetch.subphase.ScriptFieldsContext.ScriptField;
 import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
 import org.elasticsearch.search.internal.AliasFilter;
-import org.elasticsearch.search.internal.SearchContextId;
 import org.elasticsearch.search.internal.InternalScrollSearchRequest;
 import org.elasticsearch.search.internal.ScrollContext;
 import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.internal.SearchContext.Lifetime;
+import org.elasticsearch.search.internal.SearchContextId;
 import org.elasticsearch.search.internal.ShardSearchRequest;
 import org.elasticsearch.search.profile.Profilers;
 import org.elasticsearch.search.query.QueryPhase;
@@ -1200,9 +1203,31 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
         return indicesService;
     }
 
-    public InternalAggregation.ReduceContext createReduceContext(boolean finalReduce) {
-        return new InternalAggregation.ReduceContext(bigArrays, scriptService,
-            finalReduce ? multiBucketConsumerService.create() : bucketCount -> {}, finalReduce);
+    /**
+     * Returns a builder for {@link InternalAggregation.ReduceContext}. This
+     * builder retains a reference to the provided {@link SearchRequest}.
+     */
+    public InternalAggregation.ReduceContextBuilder aggReduceContextBuilder(SearchRequest request) {
+        return new InternalAggregation.ReduceContextBuilder() {
+            @Override
+            public InternalAggregation.ReduceContext forPartialReduction() {
+                return InternalAggregation.ReduceContext.forPartialReduction(bigArrays, scriptService);
+            }
+
+            @Override
+            public ReduceContext forFinalReduction() {
+                PipelineTree pipelineTree = requestToPipelineTree(request);
+                return InternalAggregation.ReduceContext.forFinalReduction(
+                        bigArrays, scriptService, multiBucketConsumerService.create(), pipelineTree);
+            }
+        };
+    }
+
+    private static PipelineTree requestToPipelineTree(SearchRequest request) {
+        if (request.source() == null || request.source().aggregations() == null) {
+            return PipelineTree.EMPTY;
+        }
+        return request.source().aggregations().buildPipelineTree();
     }
 
     static class SearchRewriteContext {

+ 10 - 0
server/src/main/java/org/elasticsearch/search/aggregations/AggregationBuilder.java

@@ -26,6 +26,8 @@ import org.elasticsearch.common.xcontent.ToXContentFragment;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.QueryShardContext;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 
 import java.io.IOException;
 import java.util.Collection;
@@ -144,6 +146,14 @@ public abstract class AggregationBuilder
         return builder;
     }
 
+    /**
+     * Build a tree of {@link PipelineAggregator}s to modify the tree of
+     * aggregation results after the final reduction.
+     */
+    public PipelineTree buildPipelineTree() {
+        return factoriesBuilder.buildPipelineTree();
+    }
+
     /** Common xcontent fields shared among aggregator builders */
     public static final class CommonFields extends ParseField.CommonFields {
         public static final ParseField VALUE_TYPE = new ParseField("value_type");

+ 25 - 17
server/src/main/java/org/elasticsearch/search/aggregations/AggregatorFactories.java

@@ -31,6 +31,7 @@ import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregationBuilder;
 import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.search.aggregations.support.AggregationPath;
 import org.elasticsearch.search.aggregations.support.AggregationPath.PathElement;
 import org.elasticsearch.search.internal.SearchContext;
@@ -52,6 +53,9 @@ import java.util.Set;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
+import static java.util.stream.Collectors.toList;
+import static java.util.stream.Collectors.toMap;
+
 public class AggregatorFactories {
     public static final Pattern VALID_AGG_NAME = Pattern.compile("[^\\[\\]>]+");
 
@@ -232,7 +236,6 @@ public class AggregatorFactories {
         // ordered nicely, although technically order does not matter
         private final Collection<AggregationBuilder> aggregationBuilders = new LinkedHashSet<>();
         private final Collection<PipelineAggregationBuilder> pipelineAggregatorBuilders = new LinkedHashSet<>();
-        private boolean skipResolveOrder;
 
         /**
          * Create an empty builder.
@@ -295,24 +298,14 @@ public class AggregatorFactories {
             return this;
         }
 
-        /**
-         * FOR TESTING ONLY
-         */
-        Builder skipResolveOrder() {
-            this.skipResolveOrder = true;
-            return this;
-        }
-
         public AggregatorFactories build(QueryShardContext queryShardContext, AggregatorFactory parent) throws IOException {
             if (aggregationBuilders.isEmpty() && pipelineAggregatorBuilders.isEmpty()) {
                 return EMPTY;
             }
             List<PipelineAggregationBuilder> orderedpipelineAggregators = null;
-            if (skipResolveOrder) {
-                orderedpipelineAggregators = new ArrayList<>(pipelineAggregatorBuilders);
-            } else {
-                orderedpipelineAggregators = resolvePipelineAggregatorOrder(this.pipelineAggregatorBuilders, this.aggregationBuilders,
-                        parent);
+            orderedpipelineAggregators = resolvePipelineAggregatorOrder(this.pipelineAggregatorBuilders, this.aggregationBuilders);
+            for (PipelineAggregationBuilder builder : orderedpipelineAggregators) {
+                builder.validate(parent, aggregationBuilders, pipelineAggregatorBuilders);
             }
             AggregatorFactory[] aggFactories = new AggregatorFactory[aggregationBuilders.size()];
 
@@ -325,8 +318,7 @@ public class AggregatorFactories {
         }
 
         private List<PipelineAggregationBuilder> resolvePipelineAggregatorOrder(
-                Collection<PipelineAggregationBuilder> pipelineAggregatorBuilders, Collection<AggregationBuilder> aggregationBuilders,
-                AggregatorFactory parent) {
+                Collection<PipelineAggregationBuilder> pipelineAggregatorBuilders, Collection<AggregationBuilder> aggregationBuilders) {
             Map<String, PipelineAggregationBuilder> pipelineAggregatorBuildersMap = new HashMap<>();
             for (PipelineAggregationBuilder builder : pipelineAggregatorBuilders) {
                 pipelineAggregatorBuildersMap.put(builder.getName(), builder);
@@ -340,7 +332,6 @@ public class AggregatorFactories {
             Collection<PipelineAggregationBuilder> temporarilyMarked = new HashSet<>();
             while (!unmarkedBuilders.isEmpty()) {
                 PipelineAggregationBuilder builder = unmarkedBuilders.get(0);
-                builder.validate(parent, aggregationBuilders, pipelineAggregatorBuilders);
                 resolvePipelineAggregatorOrder(aggBuildersMap, pipelineAggregatorBuildersMap, orderedPipelineAggregatorrs, unmarkedBuilders,
                         temporarilyMarked, builder);
             }
@@ -494,5 +485,22 @@ public class AggregatorFactories {
                 return this;
             }
         }
+
+        /**
+         * Build a tree of {@link PipelineAggregator}s to modify the tree of
+         * aggregation results after the final reduction.
+         */
+        public PipelineTree buildPipelineTree() {
+            if (aggregationBuilders.isEmpty() && pipelineAggregatorBuilders.isEmpty()) {
+                return PipelineTree.EMPTY;
+            }
+            Map<String, PipelineTree> subTrees = aggregationBuilders.stream()
+                    .collect(toMap(AggregationBuilder::getName, AggregationBuilder::buildPipelineTree));
+            List<PipelineAggregator> aggregators = resolvePipelineAggregatorOrder(pipelineAggregatorBuilders, aggregationBuilders)
+                    .stream()
+                    .map(PipelineAggregationBuilder::create)
+                    .collect(toList());
+            return new PipelineTree(subTrees, aggregators);
+        }
     }
 }

+ 46 - 10
server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregation.java

@@ -27,6 +27,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.rest.action.search.RestSearchAction;
 import org.elasticsearch.script.ScriptService;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.search.aggregations.support.AggregationPath;
 
 import java.io.IOException;
@@ -37,27 +38,54 @@ import java.util.Objects;
 import java.util.function.Function;
 import java.util.function.IntConsumer;
 
+import static java.util.Objects.requireNonNull;
+
 /**
  * An internal implementation of {@link Aggregation}. Serves as a base class for all aggregation implementations.
  */
 public abstract class InternalAggregation implements Aggregation, NamedWriteable {
-
+    /**
+     * Builds {@link ReduceContext}.
+     */
+    public interface ReduceContextBuilder {
+        /**
+         * Build a {@linkplain ReduceContext} to perform a partial reduction.
+         */
+        ReduceContext forPartialReduction();
+        /**
+         * Build a {@linkplain ReduceContext} to perform the final reduction.
+         */
+        ReduceContext forFinalReduction();
+    }
     public static class ReduceContext {
-
         private final BigArrays bigArrays;
         private final ScriptService scriptService;
         private final IntConsumer multiBucketConsumer;
-        private final boolean isFinalReduce;
+        private final PipelineTree pipelineTreeRoot;
 
-        public ReduceContext(BigArrays bigArrays, ScriptService scriptService, boolean isFinalReduce) {
-            this(bigArrays, scriptService, (s) -> {}, isFinalReduce);
+        /**
+         * Build a {@linkplain ReduceContext} to perform a partial reduction.
+         */
+        public static ReduceContext forPartialReduction(BigArrays bigArrays, ScriptService scriptService) {
+            return new ReduceContext(bigArrays, scriptService, (s) -> {}, null);
+        }
+
+        /**
+         * Build a {@linkplain ReduceContext} to perform the final reduction.
+         * @param pipelineTreeRoot The root of tree of pipeline aggregations for this request
+         */
+        public static ReduceContext forFinalReduction(BigArrays bigArrays, ScriptService scriptService,
+                IntConsumer multiBucketConsumer, PipelineTree pipelineTreeRoot) {
+            return new ReduceContext(bigArrays, scriptService, multiBucketConsumer,
+                    requireNonNull(pipelineTreeRoot, "prefer EMPTY to null"));
         }
 
-        public ReduceContext(BigArrays bigArrays, ScriptService scriptService, IntConsumer multiBucketConsumer, boolean isFinalReduce) {
+        private ReduceContext(BigArrays bigArrays, ScriptService scriptService, IntConsumer multiBucketConsumer,
+                PipelineTree pipelineTreeRoot) {
             this.bigArrays = bigArrays;
             this.scriptService = scriptService;
             this.multiBucketConsumer = multiBucketConsumer;
-            this.isFinalReduce = isFinalReduce;
+            this.pipelineTreeRoot = pipelineTreeRoot;
         }
 
         /**
@@ -66,7 +94,7 @@ public abstract class InternalAggregation implements Aggregation, NamedWriteable
          * Operations that are potentially losing information can only be applied during the final reduce phase.
          */
         public boolean isFinalReduce() {
-            return isFinalReduce;
+            return pipelineTreeRoot != null;
         }
 
         public BigArrays bigArrays() {
@@ -77,6 +105,13 @@ public abstract class InternalAggregation implements Aggregation, NamedWriteable
             return scriptService;
         }
 
+        /**
+         * The root of the tree of pipeline aggregations for this request.
+         */
+        public PipelineTree pipelineTreeRoot() {
+            return pipelineTreeRoot;
+        }
+
         /**
          * Adds {@code count} buckets to the global count for the request and fails if this number is greater than
          * the maximum number of buckets allowed in a response
@@ -155,9 +190,10 @@ public abstract class InternalAggregation implements Aggregation, NamedWriteable
      * Creates the output from all pipeline aggs that this aggregation is associated with.  Should only
      * be called after all aggregations have been fully reduced
      */
-    public InternalAggregation reducePipelines(InternalAggregation reducedAggs, ReduceContext reduceContext) {
+    public InternalAggregation reducePipelines(
+            InternalAggregation reducedAggs, ReduceContext reduceContext, PipelineTree pipelinesForThisAgg) {
         assert reduceContext.isFinalReduce();
-        for (PipelineAggregator pipelineAggregator : pipelineAggregators) {
+        for (PipelineAggregator pipelineAggregator : pipelinesForThisAgg.aggregators()) {
             reducedAggs = pipelineAggregator.reduce(reducedAggs, reduceContext);
         }
         return reducedAggs;

+ 4 - 5
server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java

@@ -139,13 +139,12 @@ public final class InternalAggregations extends Aggregations implements Writeabl
         if (context.isFinalReduce()) {
             List<InternalAggregation> reducedInternalAggs = reduced.getInternalAggregations();
             reducedInternalAggs = reducedInternalAggs.stream()
-                .map(agg -> agg.reducePipelines(agg, context))
+                .map(agg -> agg.reducePipelines(agg, context, context.pipelineTreeRoot().subTree(agg.getName())))
                 .collect(Collectors.toList());
 
-            List<SiblingPipelineAggregator> topLevelPipelineAggregators = aggregationsList.get(0).getTopLevelPipelineAggregators();
-            for (SiblingPipelineAggregator pipelineAggregator : topLevelPipelineAggregators) {
-                InternalAggregation newAgg
-                    = pipelineAggregator.doReduce(new InternalAggregations(reducedInternalAggs), context);
+            for (PipelineAggregator pipelineAggregator : context.pipelineTreeRoot().aggregators()) {
+                SiblingPipelineAggregator sib = (SiblingPipelineAggregator) pipelineAggregator;
+                InternalAggregation newAgg = sib.doReduce(new InternalAggregations(reducedInternalAggs), context);
                 reducedInternalAggs.add(newAgg);
             }
             return new InternalAggregations(reducedInternalAggs);

+ 10 - 8
server/src/main/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregation.java

@@ -24,6 +24,7 @@ import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
 import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregation;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -144,15 +145,15 @@ public abstract class InternalMultiBucketAggregation<A extends InternalMultiBuck
     }
 
     /**
-     * Unlike {@link InternalAggregation#reducePipelines(InternalAggregation, ReduceContext)}, a multi-bucket
-     * agg needs to first reduce the buckets (and their parent pipelines) before allowing sibling pipelines
-     * to materialize
+     * Amulti-bucket agg needs to first reduce the buckets and *their* pipelines
+     * before allowing sibling pipelines to materialize.
      */
     @Override
-    public final InternalAggregation reducePipelines(InternalAggregation reducedAggs, ReduceContext reduceContext) {
+    public final InternalAggregation reducePipelines(
+            InternalAggregation reducedAggs, ReduceContext reduceContext, PipelineTree pipelineTree) {
         assert reduceContext.isFinalReduce();
-        List<B> materializedBuckets = reducePipelineBuckets(reduceContext);
-        return super.reducePipelines(create(materializedBuckets), reduceContext);
+        List<B> materializedBuckets = reducePipelineBuckets(reduceContext, pipelineTree);
+        return super.reducePipelines(create(materializedBuckets), reduceContext, pipelineTree);
     }
 
     @Override
@@ -172,12 +173,13 @@ public abstract class InternalMultiBucketAggregation<A extends InternalMultiBuck
         return modified ? create(newBuckets) : this;
     }
 
-    private List<B> reducePipelineBuckets(ReduceContext reduceContext) {
+    private List<B> reducePipelineBuckets(ReduceContext reduceContext, PipelineTree pipelineTree) {
         List<B> reducedBuckets = new ArrayList<>();
         for (B bucket : getBuckets()) {
             List<InternalAggregation> aggs = new ArrayList<>();
             for (Aggregation agg : bucket.getAggregations()) {
-                aggs.add(((InternalAggregation)agg).reducePipelines((InternalAggregation)agg, reduceContext));
+                PipelineTree subTree = pipelineTree.subTree(agg.getName());
+                aggs.add(((InternalAggregation)agg).reducePipelines((InternalAggregation)agg, reduceContext, subTree));
             }
             reducedBuckets.add(createBucket(new InternalAggregations(aggs), bucket));
         }

+ 8 - 6
server/src/main/java/org/elasticsearch/search/aggregations/bucket/InternalSingleBucketAggregation.java

@@ -25,6 +25,7 @@ import org.elasticsearch.search.aggregations.Aggregation;
 import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.search.aggregations.support.AggregationPath;
 
 import java.io.IOException;
@@ -113,19 +114,20 @@ public abstract class InternalSingleBucketAggregation extends InternalAggregatio
     }
 
     /**
-     * Unlike {@link InternalAggregation#reducePipelines(InternalAggregation, ReduceContext)}, a single-bucket
-     * agg needs to first reduce the aggs in it's bucket (and their parent pipelines) before allowing sibling pipelines
-     * to reduce
+     * Amulti-bucket agg needs to first reduce the buckets and *their* pipelines
+     * before allowing sibling pipelines to materialize.
      */
     @Override
-    public final InternalAggregation reducePipelines(InternalAggregation reducedAggs, ReduceContext reduceContext) {
+    public final InternalAggregation reducePipelines(
+            InternalAggregation reducedAggs, ReduceContext reduceContext, PipelineTree pipelineTree) {
         assert reduceContext.isFinalReduce();
         List<InternalAggregation> aggs = new ArrayList<>();
         for (Aggregation agg : getAggregations().asList()) {
-            aggs.add(((InternalAggregation)agg).reducePipelines((InternalAggregation)agg, reduceContext));
+            PipelineTree subTree = pipelineTree.subTree(agg.getName());
+            aggs.add(((InternalAggregation)agg).reducePipelines((InternalAggregation)agg, reduceContext, subTree));
         }
         InternalAggregations reducedSubAggs = new InternalAggregations(aggs);
-        return super.reducePipelines(create(reducedSubAggs), reduceContext);
+        return super.reducePipelines(create(reducedSubAggs), reduceContext, pipelineTree);
     }
 
     @Override

+ 44 - 0
server/src/main/java/org/elasticsearch/search/aggregations/pipeline/PipelineAggregator.java

@@ -30,8 +30,12 @@ import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext;
 import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Map;
 
+import static java.util.Collections.emptyList;
+import static java.util.Collections.emptyMap;
+
 public abstract class PipelineAggregator implements NamedWriteable {
     /**
      * Parse the {@link PipelineAggregationBuilder} from a {@link XContentParser}.
@@ -57,6 +61,46 @@ public abstract class PipelineAggregator implements NamedWriteable {
                 throws IOException;
     }
 
+    /**
+     * Tree of {@link PipelineAggregator}s to modify a tree of aggregations
+     * after their final reduction.
+     */
+    public static class PipelineTree {
+        /**
+         * An empty tree of {@link PipelineAggregator}s.
+         */
+        public static final PipelineTree EMPTY = new PipelineTree(emptyMap(), emptyList());
+
+        private final Map<String, PipelineTree> subTrees;
+        private final List<PipelineAggregator> aggregators;
+
+        public PipelineTree(Map<String, PipelineTree> subTrees, List<PipelineAggregator> aggregators) {
+            this.subTrees = subTrees;
+            this.aggregators = aggregators;
+        }
+
+        /**
+         * The {@link PipelineAggregator}s for the aggregation at this
+         * position in the tree.
+         */
+        public List<PipelineAggregator> aggregators() {
+            return aggregators;
+        }
+
+        /**
+         * Get the sub-tree at for the named sub-aggregation or {@link #EMPTY}
+         * if there are no pipeline aggragations for that sub-aggregator.
+         */
+        public PipelineTree subTree(String name) {
+            return subTrees.getOrDefault(name, EMPTY);
+        }
+
+        @Override
+        public String toString() {
+            return "PipelineTree[" + aggregators + "," + subTrees + "]";
+        }
+    }
+
     private String name;
     private String[] bucketsPaths;
     private Map<String, Object> metaData;

+ 7 - 11
server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java

@@ -27,18 +27,17 @@ import org.apache.lucene.store.MockDirectoryWrapper;
 import org.elasticsearch.action.OriginalIndices;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
-import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
-import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.dfs.DfsSearchResult;
 import org.elasticsearch.search.internal.SearchContextId;
 import org.elasticsearch.search.query.QuerySearchRequest;
 import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.InternalAggregationTestCase;
 import org.elasticsearch.transport.Transport;
 
 import java.io.IOException;
@@ -63,8 +62,6 @@ public class DfsQueryPhaseTests extends ESTestCase {
         results.get(0).termsStatistics(new Term[0], new TermStatistics[0]);
         results.get(1).termsStatistics(new Term[0], new TermStatistics[0]);
 
-        SearchPhaseController controller = new SearchPhaseController(
-            (b) -> new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, b));
         SearchTransportService searchTransportService = new SearchTransportService(null, null) {
             @Override
             public void sendExecuteQuery(Transport.Connection connection, QuerySearchRequest request, SearchTask task,
@@ -92,7 +89,7 @@ public class DfsQueryPhaseTests extends ESTestCase {
         };
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
         mockSearchPhaseContext.searchTransport = searchTransportService;
-        DfsQueryPhase phase = new DfsQueryPhase(results, controller,
+        DfsQueryPhase phase = new DfsQueryPhase(results, searchPhaseController(),
             (response) -> new SearchPhase("test") {
             @Override
             public void run() throws IOException {
@@ -125,8 +122,6 @@ public class DfsQueryPhaseTests extends ESTestCase {
         results.get(0).termsStatistics(new Term[0], new TermStatistics[0]);
         results.get(1).termsStatistics(new Term[0], new TermStatistics[0]);
 
-        SearchPhaseController controller = new SearchPhaseController(
-            (b) -> new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, b));
         SearchTransportService searchTransportService = new SearchTransportService(null, null) {
             @Override
             public void sendExecuteQuery(Transport.Connection connection, QuerySearchRequest request, SearchTask task,
@@ -148,7 +143,7 @@ public class DfsQueryPhaseTests extends ESTestCase {
         };
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
         mockSearchPhaseContext.searchTransport = searchTransportService;
-        DfsQueryPhase phase = new DfsQueryPhase(results, controller,
+        DfsQueryPhase phase = new DfsQueryPhase(results, searchPhaseController(),
             (response) -> new SearchPhase("test") {
                 @Override
                 public void run() throws IOException {
@@ -184,8 +179,6 @@ public class DfsQueryPhaseTests extends ESTestCase {
         results.get(0).termsStatistics(new Term[0], new TermStatistics[0]);
         results.get(1).termsStatistics(new Term[0], new TermStatistics[0]);
 
-        SearchPhaseController controller = new SearchPhaseController(
-            (b) -> new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, b));
         SearchTransportService searchTransportService = new SearchTransportService(null, null) {
             @Override
             public void sendExecuteQuery(Transport.Connection connection, QuerySearchRequest request, SearchTask task,
@@ -207,7 +200,7 @@ public class DfsQueryPhaseTests extends ESTestCase {
         };
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
         mockSearchPhaseContext.searchTransport = searchTransportService;
-        DfsQueryPhase phase = new DfsQueryPhase(results, controller,
+        DfsQueryPhase phase = new DfsQueryPhase(results, searchPhaseController(),
             (response) -> new SearchPhase("test") {
                 @Override
                 public void run() throws IOException {
@@ -219,4 +212,7 @@ public class DfsQueryPhaseTests extends ESTestCase {
         assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); // phase execution will clean up on the contexts
     }
 
+    private SearchPhaseController searchPhaseController() {
+        return new SearchPhaseController(request -> InternalAggregationTestCase.emptyReduceContextBuilder());
+    }
  }

+ 7 - 14
server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java

@@ -26,20 +26,19 @@ import org.elasticsearch.action.OriginalIndices;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
-import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
-import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.fetch.FetchSearchResult;
 import org.elasticsearch.search.fetch.QueryFetchSearchResult;
 import org.elasticsearch.search.fetch.ShardFetchSearchRequest;
 import org.elasticsearch.search.internal.SearchContextId;
 import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.InternalAggregationTestCase;
 import org.elasticsearch.transport.Transport;
 
 import java.util.concurrent.CountDownLatch;
@@ -50,8 +49,7 @@ import static org.elasticsearch.action.search.SearchProgressListener.NOOP;
 public class FetchSearchPhaseTests extends ESTestCase {
 
     public void testShortcutQueryAndFetchOptimization() {
-        SearchPhaseController controller = new SearchPhaseController(
-            (b) -> new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, b));
+        SearchPhaseController controller = new SearchPhaseController(s -> InternalAggregationTestCase.emptyReduceContextBuilder());
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1);
         ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 1);
         boolean hasHits = randomBoolean();
@@ -94,8 +92,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
 
     public void testFetchTwoDocument() {
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
-        SearchPhaseController controller = new SearchPhaseController(
-            (b) -> new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, b));
+        SearchPhaseController controller = new SearchPhaseController(s -> InternalAggregationTestCase.emptyReduceContextBuilder());
         ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2);
         int resultSetSize = randomIntBetween(2, 10);
         final SearchContextId ctx1 = new SearchContextId(UUIDs.randomBase64UUID(), 123);
@@ -154,8 +151,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
 
     public void testFailFetchOneDoc() {
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
-        SearchPhaseController controller = new SearchPhaseController(
-            (b) -> new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, b));
+        SearchPhaseController controller = new SearchPhaseController(s -> InternalAggregationTestCase.emptyReduceContextBuilder());
         ArraySearchPhaseResults<SearchPhaseResult> results =
             controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2);
         int resultSetSize = randomIntBetween(2, 10);
@@ -218,8 +214,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
         int resultSetSize = randomIntBetween(0, 100);
         // we use at least 2 hits otherwise this is subject to single shard optimization and we trip an assert...
         int numHits = randomIntBetween(2, 100); // also numshards --> 1 hit per shard
-        SearchPhaseController controller = new SearchPhaseController(
-            (b) -> new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, b));
+        SearchPhaseController controller = new SearchPhaseController(s -> InternalAggregationTestCase.emptyReduceContextBuilder());
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(numHits);
         ArraySearchPhaseResults<SearchPhaseResult> results = controller.newSearchPhaseResults(NOOP,
             mockSearchPhaseContext.getRequest(), numHits);
@@ -276,8 +271,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
 
     public void testExceptionFailsPhase() {
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
-        SearchPhaseController controller = new SearchPhaseController(
-            (b) -> new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, b));
+        SearchPhaseController controller = new SearchPhaseController(s -> InternalAggregationTestCase.emptyReduceContextBuilder());
         ArraySearchPhaseResults<SearchPhaseResult> results =
             controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2);
         int resultSetSize = randomIntBetween(2, 10);
@@ -333,8 +327,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
 
     public void testCleanupIrrelevantContexts() { // contexts that are not fetched should be cleaned up
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
-        SearchPhaseController controller = new SearchPhaseController(
-            (b) -> new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, b));
+        SearchPhaseController controller = new SearchPhaseController(s -> InternalAggregationTestCase.emptyReduceContextBuilder());
         ArraySearchPhaseResults<SearchPhaseResult> results =
             controller.newSearchPhaseResults(NOOP, mockSearchPhaseContext.getRequest(), 2);
         int resultSetSize = 1;

+ 38 - 20
server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java

@@ -20,6 +20,7 @@
 package org.elasticsearch.action.search;
 
 import com.carrotsearch.randomizedtesting.RandomizedContext;
+
 import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.SortField;
@@ -45,8 +46,10 @@ import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.InternalAggregation;
+import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.aggregations.metrics.InternalMax;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.fetch.FetchSearchResult;
 import org.elasticsearch.search.internal.InternalSearchResponse;
@@ -59,6 +62,7 @@ import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
 import org.elasticsearch.search.suggest.phrase.PhraseSuggestion;
 import org.elasticsearch.search.suggest.term.TermSuggestion;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.InternalAggregationTestCase;
 import org.junit.Before;
 
 import java.util.ArrayList;
@@ -89,11 +93,19 @@ public class SearchPhaseControllerTests extends ESTestCase {
     @Before
     public void setup() {
         reductions = new CopyOnWriteArrayList<>();
-        searchPhaseController = new SearchPhaseController(
-            (finalReduce) -> {
-                reductions.add(finalReduce);
-                return new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, finalReduce);
-            });
+        searchPhaseController = new SearchPhaseController(s -> new InternalAggregation.ReduceContextBuilder() {
+            @Override
+            public ReduceContext forPartialReduction() {
+                reductions.add(false);
+                return InternalAggregation.ReduceContext.forPartialReduction(BigArrays.NON_RECYCLING_INSTANCE, null);
+            }
+
+            public ReduceContext forFinalReduction() {
+                reductions.add(true);
+                return InternalAggregation.ReduceContext.forFinalReduction(
+                        BigArrays.NON_RECYCLING_INSTANCE, null, b -> {}, PipelineTree.EMPTY);
+            };
+        });
     }
 
     public void testSortDocs() {
@@ -176,8 +188,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
         int queryResultSize = randomBoolean() ? 0 : randomIntBetween(1, nShards * 2);
         AtomicArray<SearchPhaseResult> queryResults = generateQueryResults(nShards, suggestions, queryResultSize, false);
         for (int trackTotalHits : new int[] {SearchContext.TRACK_TOTAL_HITS_DISABLED, SearchContext.TRACK_TOTAL_HITS_ACCURATE}) {
-            SearchPhaseController.ReducedQueryPhase reducedQueryPhase =
-                searchPhaseController.reducedQueryPhase(queryResults.asList(), false, trackTotalHits, true);
+            SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase(
+                    queryResults.asList(), false, trackTotalHits, InternalAggregationTestCase.emptyReduceContextBuilder(), true);
             AtomicArray<SearchPhaseResult> fetchResults = generateFetchResults(nShards,
                 reducedQueryPhase.sortedTopDocs.scoreDocs, reducedQueryPhase.suggest);
             InternalSearchResponse mergedResponse = searchPhaseController.merge(false,
@@ -221,8 +233,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
      * Generate random query results received from the provided number of shards, including the provided
      * number of search hits and randomly generated completion suggestions based on the name and size of the provided ones.
      * Note that <code>shardIndex</code> is already set to the generated completion suggestions to simulate what
-     * {@link SearchPhaseController#reducedQueryPhase(Collection, boolean, int, boolean)} does, meaning that the returned query results
-     * can be fed directly to
+     * {@link SearchPhaseController#reducedQueryPhase(Collection, boolean, int, InternalAggregation.ReduceContextBuilder, boolean)} does,
+     * meaning that the returned query results can be fed directly to
      * {@link SearchPhaseController#sortDocs(boolean, Collection, Collection, SearchPhaseController.TopDocsStats, int, int, List)}
      */
     private static AtomicArray<SearchPhaseResult> generateQueryResults(int nShards, List<CompletionSuggestion> suggestions,
@@ -431,7 +443,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
         SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
         assertEquals(numTotalReducePhases, reduce.numReducePhases);
         assertEquals(numTotalReducePhases, reductions.size());
-        assertFinalReduction(request);
+        assertAggReduction(request);
         InternalMax max = (InternalMax) reduce.aggregations.asList().get(0);
         assertEquals(3.0D, max.getValue(), 0.0D);
         assertFalse(reduce.sortedTopDocs.isSortedByField);
@@ -475,7 +487,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
             threads[i].join();
         }
         SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
-        assertFinalReduction(request);
+        assertAggReduction(request);
         InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0);
         assertEquals(max.get(), internalMax.getValue(), 0.0D);
         assertEquals(1, reduce.sortedTopDocs.scoreDocs.length);
@@ -512,7 +524,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
             consumer.consumeResult(result);
         }
         SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
-        assertFinalReduction(request);
+        assertAggReduction(request);
         InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0);
         assertEquals(max.get(), internalMax.getValue(), 0.0D);
         assertEquals(0, reduce.sortedTopDocs.scoreDocs.length);
@@ -547,7 +559,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
             consumer.consumeResult(result);
         }
         SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
-        assertFinalReduction(request);
+        assertAggReduction(request);
         assertEquals(1, reduce.sortedTopDocs.scoreDocs.length);
         assertEquals(max.get(), reduce.maxScore, 0.0f);
         assertEquals(expectedNumResults, reduce.totalHits.value);
@@ -558,9 +570,15 @@ public class SearchPhaseControllerTests extends ESTestCase {
         assertNull(reduce.sortedTopDocs.collapseValues);
     }
 
-    private void assertFinalReduction(SearchRequest searchRequest) {
-        assertThat(reductions.size(), greaterThanOrEqualTo(1));
-        assertEquals(searchRequest.isFinalReduce(), reductions.get(reductions.size() - 1));
+    private void assertAggReduction(SearchRequest searchRequest) {
+        if (searchRequest.source() == null || searchRequest.source().aggregations() == null ||
+                searchRequest.source().aggregations().getAggregatorFactories().isEmpty()) {
+            // When there aren't any aggregations we don't perform any aggregation reductions.
+            assertThat(reductions.size(), equalTo(0));
+        } else {
+            assertThat(reductions.size(), greaterThanOrEqualTo(1));
+            assertEquals(searchRequest.isFinalReduce(), reductions.get(reductions.size() - 1));
+        }
     }
 
     public void testNewSearchPhaseResults() {
@@ -655,7 +673,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
             consumer.consumeResult(result);
         }
         SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
-        assertFinalReduction(request);
+        assertAggReduction(request);
         assertEquals(Math.min(expectedNumResults, size), reduce.sortedTopDocs.scoreDocs.length);
         assertEquals(expectedNumResults, reduce.totalHits.value);
         assertEquals(max.get(), ((FieldDoc)reduce.sortedTopDocs.scoreDocs[0]).fields[0]);
@@ -693,7 +711,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
             consumer.consumeResult(result);
         }
         SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
-        assertFinalReduction(request);
+        assertAggReduction(request);
         assertEquals(3, reduce.sortedTopDocs.scoreDocs.length);
         assertEquals(expectedNumResults, reduce.totalHits.value);
         assertEquals(a, ((FieldDoc)reduce.sortedTopDocs.scoreDocs[0]).fields[0]);
@@ -787,7 +805,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
             CompletionSuggestion.Entry.Option option = completion.getOptions().get(0);
             assertEquals(maxScoreCompletion, option.getScore(), 0f);
         }
-        assertFinalReduction(request);
+        assertAggReduction(request);
         assertEquals(1, reduce.sortedTopDocs.scoreDocs.length);
         assertEquals(maxScoreCompletion, reduce.sortedTopDocs.scoreDocs[0].score, 0f);
         assertEquals(0, reduce.sortedTopDocs.scoreDocs[0].doc);
@@ -862,7 +880,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
                 threads[i].join();
             }
             SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
-            assertFinalReduction(request);
+            assertAggReduction(request);
             InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0);
             assertEquals(max.get(), internalMax.getValue(), 0.0D);
             assertEquals(1, reduce.sortedTopDocs.scoreDocs.length);

+ 16 - 13
server/src/test/java/org/elasticsearch/action/search/SearchResponseMergerTests.java

@@ -32,7 +32,6 @@ import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchShardTarget;
-import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.aggregations.bucket.range.InternalDateRange;
 import org.elasticsearch.search.aggregations.bucket.range.Range;
@@ -64,6 +63,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 
+import static org.elasticsearch.test.InternalAggregationTestCase.emptyReduceContextBuilder;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -97,7 +97,7 @@ public class SearchResponseMergerTests extends ESTestCase {
         long currentRelativeTime = randomLong();
         SearchTimeProvider timeProvider = new SearchTimeProvider(randomLong(), 0, () -> currentRelativeTime);
         SearchResponseMerger merger = new SearchResponseMerger(randomIntBetween(0, 1000), randomIntBetween(0, 10000),
-            SearchContext.TRACK_TOTAL_HITS_ACCURATE, timeProvider, flag -> null);
+            SearchContext.TRACK_TOTAL_HITS_ACCURATE, timeProvider, emptyReduceContextBuilder());
         for (int i = 0; i < numResponses; i++) {
             SearchResponse searchResponse = new SearchResponse(InternalSearchResponse.empty(), null, 1, 1, 0, randomLong(),
                 ShardSearchFailure.EMPTY_ARRAY, SearchResponseTests.randomClusters());
@@ -111,7 +111,7 @@ public class SearchResponseMergerTests extends ESTestCase {
     public void testMergeShardFailures() throws InterruptedException {
         SearchTimeProvider searchTimeProvider = new SearchTimeProvider(0, 0, () -> 0);
         SearchResponseMerger merger = new SearchResponseMerger(0, 0, SearchContext.TRACK_TOTAL_HITS_ACCURATE,
-            searchTimeProvider, flag -> null);
+            searchTimeProvider, emptyReduceContextBuilder());
         PriorityQueue<Tuple<SearchShardTarget, ShardSearchFailure>> priorityQueue = new PriorityQueue<>(Comparator.comparing(Tuple::v1,
             (o1, o2) -> {
                 int compareTo = o1.getShardId().compareTo(o2.getShardId());
@@ -159,7 +159,7 @@ public class SearchResponseMergerTests extends ESTestCase {
     public void testMergeShardFailuresNullShardTarget() throws InterruptedException {
         SearchTimeProvider searchTimeProvider = new SearchTimeProvider(0, 0, () -> 0);
         SearchResponseMerger merger = new SearchResponseMerger(0, 0, SearchContext.TRACK_TOTAL_HITS_ACCURATE,
-            searchTimeProvider, flag -> null);
+            searchTimeProvider, emptyReduceContextBuilder());
         PriorityQueue<Tuple<ShardId, ShardSearchFailure>> priorityQueue = new PriorityQueue<>(Comparator.comparing(Tuple::v1));
         for (int i = 0; i < numResponses; i++) {
             int numFailures = randomIntBetween(1, 10);
@@ -197,7 +197,7 @@ public class SearchResponseMergerTests extends ESTestCase {
     public void testMergeShardFailuresNullShardId() throws InterruptedException {
         SearchTimeProvider searchTimeProvider = new SearchTimeProvider(0, 0, () -> 0);
         SearchResponseMerger merger = new SearchResponseMerger(0, 0, SearchContext.TRACK_TOTAL_HITS_ACCURATE,
-            searchTimeProvider, flag -> null);
+            searchTimeProvider, emptyReduceContextBuilder());
         List<ShardSearchFailure> expectedFailures = new ArrayList<>();
         for (int i = 0; i < numResponses; i++) {
             int numFailures = randomIntBetween(1, 50);
@@ -220,7 +220,7 @@ public class SearchResponseMergerTests extends ESTestCase {
     public void testMergeProfileResults() throws InterruptedException {
         SearchTimeProvider searchTimeProvider = new SearchTimeProvider(0, 0, () -> 0);
         SearchResponseMerger merger = new SearchResponseMerger(0, 0, SearchContext.TRACK_TOTAL_HITS_ACCURATE,
-            searchTimeProvider, flag -> null);
+            searchTimeProvider, emptyReduceContextBuilder());
         Map<String, ProfileShardResult> expectedProfile = new HashMap<>();
         for (int i = 0; i < numResponses; i++) {
             SearchProfileShardResults profile = SearchProfileShardResultsTests.createTestItem();
@@ -247,7 +247,8 @@ public class SearchResponseMergerTests extends ESTestCase {
     public void testMergeCompletionSuggestions()  throws InterruptedException {
         String suggestionName = randomAlphaOfLengthBetween(4, 8);
         int size = randomIntBetween(1, 100);
-        SearchResponseMerger searchResponseMerger = new SearchResponseMerger(0, 0, 0, new SearchTimeProvider(0, 0, () -> 0), flag -> null);
+        SearchResponseMerger searchResponseMerger = new SearchResponseMerger(0, 0, 0, new SearchTimeProvider(0, 0, () -> 0),
+                emptyReduceContextBuilder());
         for (int i = 0; i < numResponses; i++) {
             List<Suggest.Suggestion<? extends Suggest.Suggestion.Entry<? extends Suggest.Suggestion.Entry.Option>>> suggestions =
                 new ArrayList<>();
@@ -296,7 +297,8 @@ public class SearchResponseMergerTests extends ESTestCase {
     public void testMergeCompletionSuggestionsTieBreak()  throws InterruptedException {
         String suggestionName = randomAlphaOfLengthBetween(4, 8);
         int size = randomIntBetween(1, 100);
-        SearchResponseMerger searchResponseMerger = new SearchResponseMerger(0, 0, 0, new SearchTimeProvider(0, 0, () -> 0), flag -> null);
+        SearchResponseMerger searchResponseMerger = new SearchResponseMerger(0, 0, 0, new SearchTimeProvider(0, 0, () -> 0),
+                emptyReduceContextBuilder());
         for (int i = 0; i < numResponses; i++) {
             List<Suggest.Suggestion<? extends Suggest.Suggestion.Entry<? extends Suggest.Suggestion.Entry.Option>>> suggestions =
                 new ArrayList<>();
@@ -351,7 +353,7 @@ public class SearchResponseMergerTests extends ESTestCase {
 
     public void testMergeAggs() throws InterruptedException {
         SearchResponseMerger searchResponseMerger = new SearchResponseMerger(0, 0, 0, new SearchTimeProvider(0, 0, () -> 0),
-            flag -> new InternalAggregation.ReduceContext(null, null, flag));
+                emptyReduceContextBuilder());
         String maxAggName = randomAlphaOfLengthBetween(5, 8);
         String rangeAggName = randomAlphaOfLengthBetween(5, 8);
         int totalCount = 0;
@@ -429,7 +431,8 @@ public class SearchResponseMergerTests extends ESTestCase {
         TotalHits.Relation totalHitsRelation = randomTrackTotalHits.v2();
 
         PriorityQueue<SearchHit> priorityQueue = new PriorityQueue<>(new SearchHitComparator(sortFields));
-        SearchResponseMerger searchResponseMerger = new SearchResponseMerger(from, size, trackTotalHitsUpTo, timeProvider, flag -> null);
+        SearchResponseMerger searchResponseMerger = new SearchResponseMerger(
+                from, size, trackTotalHitsUpTo, timeProvider, emptyReduceContextBuilder());
 
         TotalHits expectedTotalHits = null;
         int expectedTotal = 0;
@@ -556,7 +559,7 @@ public class SearchResponseMergerTests extends ESTestCase {
     public void testMergeNoResponsesAdded() {
         long currentRelativeTime = randomLong();
         final SearchTimeProvider timeProvider = new SearchTimeProvider(randomLong(), 0, () -> currentRelativeTime);
-        SearchResponseMerger merger = new SearchResponseMerger(0, 10, Integer.MAX_VALUE, timeProvider, flag -> null);
+        SearchResponseMerger merger = new SearchResponseMerger(0, 10, Integer.MAX_VALUE, timeProvider, emptyReduceContextBuilder());
         SearchResponse.Clusters clusters = SearchResponseTests.randomClusters();
         assertEquals(0, merger.numResponses());
         SearchResponse response = merger.getMergedResponse(clusters);
@@ -583,7 +586,7 @@ public class SearchResponseMergerTests extends ESTestCase {
     public void testMergeEmptySearchHitsWithNonEmpty() {
         long currentRelativeTime = randomLong();
         final SearchTimeProvider timeProvider = new SearchTimeProvider(randomLong(), 0, () -> currentRelativeTime);
-        SearchResponseMerger merger = new SearchResponseMerger(0, 10, Integer.MAX_VALUE, timeProvider, flag -> null);
+        SearchResponseMerger merger = new SearchResponseMerger(0, 10, Integer.MAX_VALUE, timeProvider, emptyReduceContextBuilder());
         SearchResponse.Clusters clusters = SearchResponseTests.randomClusters();
         int numFields = randomIntBetween(1, 3);
         SortField[] sortFields = new SortField[numFields];
@@ -626,7 +629,7 @@ public class SearchResponseMergerTests extends ESTestCase {
         Tuple<Integer, TotalHits.Relation> randomTrackTotalHits = randomTrackTotalHits();
         int trackTotalHitsUpTo = randomTrackTotalHits.v1();
         TotalHits.Relation totalHitsRelation = randomTrackTotalHits.v2();
-        SearchResponseMerger merger = new SearchResponseMerger(0, 10, trackTotalHitsUpTo, timeProvider, flag -> null);
+        SearchResponseMerger merger = new SearchResponseMerger(0, 10, trackTotalHitsUpTo, timeProvider, emptyReduceContextBuilder());
         int numResponses = randomIntBetween(1, 5);
         TotalHits expectedTotalHits = null;
         for (int i = 0; i < numResponses; i++) {

+ 32 - 18
server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java

@@ -52,6 +52,7 @@ import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.collapse.CollapseBuilder;
 import org.elasticsearch.search.internal.AliasFilter;
@@ -89,6 +90,8 @@ import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiFunction;
 import java.util.function.Function;
 
+import static java.util.Collections.emptyList;
+import static java.util.Collections.emptyMap;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
 import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.CoreMatchers.instanceOf;
@@ -385,8 +388,8 @@ public class TransportSearchActionTests extends ESTestCase {
             AtomicReference<Exception> failure = new AtomicReference<>();
             LatchedActionListener<SearchResponse> listener = new LatchedActionListener<>(
                 ActionListener.wrap(r -> fail("no response expected"), failure::set), latch);
-            TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider, reduceContext,
-                remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
+            TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
+                    aggReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
             if (localIndices == null) {
                 assertNull(setOnce.get());
             } else {
@@ -419,8 +422,6 @@ public class TransportSearchActionTests extends ESTestCase {
         OriginalIndices localIndices = local ? new OriginalIndices(new String[]{"index"}, SearchRequest.DEFAULT_INDICES_OPTIONS) : null;
         int totalClusters = numClusters + (local ? 1 : 0);
         TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0);
-        Function<Boolean, InternalAggregation.ReduceContext> reduceContext =
-            finalReduce -> new InternalAggregation.ReduceContext(null, null, finalReduce);
         try (MockTransportService service = MockTransportService.createNewService(settings, Version.CURRENT, threadPool, null)) {
             service.start();
             service.acceptIncomingRequests();
@@ -432,8 +433,8 @@ public class TransportSearchActionTests extends ESTestCase {
                 AtomicReference<SearchResponse> response = new AtomicReference<>();
                 LatchedActionListener<SearchResponse> listener = new LatchedActionListener<>(
                     ActionListener.wrap(response::set, e -> fail("no failures expected")), latch);
-                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider, reduceContext,
-                    remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
+                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
+                        aggReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
                 if (localIndices == null) {
                     assertNull(setOnce.get());
                 } else {
@@ -458,8 +459,8 @@ public class TransportSearchActionTests extends ESTestCase {
                 AtomicReference<Exception> failure = new AtomicReference<>();
                 LatchedActionListener<SearchResponse> listener = new LatchedActionListener<>(
                     ActionListener.wrap(r -> fail("no response expected"), failure::set), latch);
-                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider, reduceContext,
-                    remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
+                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
+                        aggReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
                 if (localIndices == null) {
                     assertNull(setOnce.get());
                 } else {
@@ -505,8 +506,8 @@ public class TransportSearchActionTests extends ESTestCase {
                 AtomicReference<Exception> failure = new AtomicReference<>();
                 LatchedActionListener<SearchResponse> listener = new LatchedActionListener<>(
                     ActionListener.wrap(r -> fail("no response expected"), failure::set), latch);
-                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider, reduceContext,
-                    remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
+                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
+                        aggReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
                 if (localIndices == null) {
                     assertNull(setOnce.get());
                 } else {
@@ -534,8 +535,8 @@ public class TransportSearchActionTests extends ESTestCase {
                 AtomicReference<SearchResponse> response = new AtomicReference<>();
                 LatchedActionListener<SearchResponse> listener = new LatchedActionListener<>(
                     ActionListener.wrap(response::set, e -> fail("no failures expected")), latch);
-                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider, reduceContext,
-                    remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
+                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider, 
+                        aggReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
                 if (localIndices == null) {
                     assertNull(setOnce.get());
                 } else {
@@ -574,8 +575,8 @@ public class TransportSearchActionTests extends ESTestCase {
                 AtomicReference<SearchResponse> response = new AtomicReference<>();
                 LatchedActionListener<SearchResponse> listener = new LatchedActionListener<>(
                     ActionListener.wrap(response::set, e -> fail("no failures expected")), latch);
-                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider, reduceContext,
-                    remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
+                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
+                        aggReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
                 if (localIndices == null) {
                     assertNull(setOnce.get());
                 } else {
@@ -751,13 +752,12 @@ public class TransportSearchActionTests extends ESTestCase {
 
     public void testCreateSearchResponseMerger() {
         TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0);
-        Function<Boolean, InternalAggregation.ReduceContext> reduceContext = flag -> null;
         {
             SearchSourceBuilder source = new SearchSourceBuilder();
             assertEquals(-1, source.size());
             assertEquals(-1, source.from());
             assertNull(source.trackTotalHitsUpTo());
-            SearchResponseMerger merger = TransportSearchAction.createSearchResponseMerger(source, timeProvider, reduceContext);
+            SearchResponseMerger merger = TransportSearchAction.createSearchResponseMerger(source, timeProvider, aggReduceContextBuilder());
             assertEquals(0, merger.from);
             assertEquals(10, merger.size);
             assertEquals(SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO, merger.trackTotalHitsUpTo);
@@ -766,7 +766,7 @@ public class TransportSearchActionTests extends ESTestCase {
             assertNull(source.trackTotalHitsUpTo());
         }
         {
-            SearchResponseMerger merger = TransportSearchAction.createSearchResponseMerger(null, timeProvider, reduceContext);
+            SearchResponseMerger merger = TransportSearchAction.createSearchResponseMerger(null, timeProvider, aggReduceContextBuilder());
             assertEquals(0, merger.from);
             assertEquals(10, merger.size);
             assertEquals(SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO, merger.trackTotalHitsUpTo);
@@ -779,7 +779,7 @@ public class TransportSearchActionTests extends ESTestCase {
             source.size(originalSize);
             int trackTotalHitsUpTo = randomIntBetween(0, Integer.MAX_VALUE);
             source.trackTotalHitsUpTo(trackTotalHitsUpTo);
-            SearchResponseMerger merger = TransportSearchAction.createSearchResponseMerger(source, timeProvider, reduceContext);
+            SearchResponseMerger merger = TransportSearchAction.createSearchResponseMerger(source, timeProvider, aggReduceContextBuilder());
             assertEquals(0, source.from());
             assertEquals(originalFrom + originalSize, source.size());
             assertEquals(trackTotalHitsUpTo, (int)source.trackTotalHitsUpTo());
@@ -837,4 +837,18 @@ public class TransportSearchActionTests extends ESTestCase {
             assertFalse(TransportSearchAction.shouldMinimizeRoundtrips(searchRequest));
         }
     }
+
+    private InternalAggregation.ReduceContextBuilder aggReduceContextBuilder() {
+        return new InternalAggregation.ReduceContextBuilder() {
+            @Override
+            public InternalAggregation.ReduceContext forPartialReduction() {
+                return InternalAggregation.ReduceContext.forPartialReduction(null, null);
+            };
+
+            @Override
+            public InternalAggregation.ReduceContext forFinalReduction() {
+                return InternalAggregation.ReduceContext.forFinalReduction(null, null, b -> {}, new PipelineTree(emptyMap(), emptyList()));
+            }
+        };
+    }
 }

+ 4 - 3
server/src/test/java/org/elasticsearch/search/SearchServiceTests.java

@@ -777,14 +777,15 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
     }
 
     public void testCreateReduceContext() {
-        final SearchService service = getInstanceFromNode(SearchService.class);
+        SearchService service = getInstanceFromNode(SearchService.class);
+        InternalAggregation.ReduceContextBuilder reduceContextBuilder = service.aggReduceContextBuilder(new SearchRequest());
         {
-            InternalAggregation.ReduceContext reduceContext = service.createReduceContext(true);
+            InternalAggregation.ReduceContext reduceContext = reduceContextBuilder.forFinalReduction();
             expectThrows(MultiBucketConsumerService.TooManyBucketsException.class,
                 () -> reduceContext.consumeBucketsAndMaybeBreak(MultiBucketConsumerService.DEFAULT_MAX_BUCKETS + 1));
         }
         {
-            InternalAggregation.ReduceContext reduceContext = service.createReduceContext(false);
+            InternalAggregation.ReduceContext reduceContext = reduceContextBuilder.forPartialReduction();
             reduceContext.consumeBucketsAndMaybeBreak(MultiBucketConsumerService.DEFAULT_MAX_BUCKETS + 1);
         }
     }

+ 14 - 0
server/src/test/java/org/elasticsearch/search/aggregations/AggregatorFactoriesTests.java

@@ -36,15 +36,19 @@ import org.elasticsearch.script.Script;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
 import org.elasticsearch.search.aggregations.pipeline.BucketScriptPipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.test.AbstractQueryTestCase;
 import org.elasticsearch.test.ESTestCase;
 
 import java.util.Collection;
+import java.util.List;
 import java.util.Random;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
 import static java.util.Collections.emptyList;
+import static java.util.stream.Collectors.toList;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
@@ -228,6 +232,16 @@ public class AggregatorFactoriesTests extends ESTestCase {
         assertSame(rewritten, secondRewritten);
     }
 
+    public void testBuildPipelineTreeResolvesPipelineOrder() {
+        AggregatorFactories.Builder builder = new AggregatorFactories.Builder();
+        builder.addPipelineAggregator(PipelineAggregatorBuilders.avgBucket("bar", "foo"));
+        builder.addPipelineAggregator(PipelineAggregatorBuilders.avgBucket("foo", "real"));
+        builder.addAggregator(AggregationBuilders.avg("real").field("target"));
+        PipelineTree tree = builder.buildPipelineTree();
+        assertThat(tree.aggregators().stream().map(PipelineAggregator::name).collect(toList()),
+                equalTo(List.of("foo", "bar")));
+    }
+
     @Override
     protected NamedXContentRegistry xContentRegistry() {
         return xContentRegistry;

+ 1 - 1
server/src/test/java/org/elasticsearch/search/aggregations/BasePipelineAggregationTestCase.java

@@ -96,7 +96,7 @@ public abstract class BasePipelineAggregationTestCase<AF extends AbstractPipelin
      */
     public void testFromXContent() throws IOException {
         AF testAgg = createTestAggregatorFactory();
-        AggregatorFactories.Builder factoriesBuilder = AggregatorFactories.builder().skipResolveOrder().addPipelineAggregator(testAgg);
+        AggregatorFactories.Builder factoriesBuilder = AggregatorFactories.builder().addPipelineAggregator(testAgg);
         logger.info("Content string: {}", factoriesBuilder);
         XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values()));
         if (randomBoolean()) {

+ 14 - 13
server/src/test/java/org/elasticsearch/search/aggregations/InternalAggregationsTests.java

@@ -24,6 +24,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.aggregations.bucket.histogram.InternalDateHistogramTests;
@@ -32,9 +33,11 @@ import org.elasticsearch.search.aggregations.bucket.terms.StringTermsTests;
 import org.elasticsearch.search.aggregations.pipeline.AvgBucketPipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.pipeline.InternalSimpleValueTests;
 import org.elasticsearch.search.aggregations.pipeline.MaxBucketPipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.search.aggregations.pipeline.SiblingPipelineAggregator;
 import org.elasticsearch.search.aggregations.pipeline.SumBucketPipelineAggregationBuilder;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.InternalAggregationTestCase;
 import org.elasticsearch.test.VersionUtils;
 
 import java.io.IOException;
@@ -42,6 +45,10 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 
+import static java.util.Collections.emptyList;
+import static java.util.Collections.emptyMap;
+import static java.util.Collections.singletonList;
+
 public class InternalAggregationsTests extends ESTestCase {
 
     private final NamedWriteableRegistry registry = new NamedWriteableRegistry(
@@ -49,7 +56,8 @@ public class InternalAggregationsTests extends ESTestCase {
 
     public void testReduceEmptyAggs() {
         List<InternalAggregations> aggs = Collections.emptyList();
-        InternalAggregation.ReduceContext reduceContext = new InternalAggregation.ReduceContext(null, null, randomBoolean());
+        InternalAggregation.ReduceContextBuilder builder = InternalAggregationTestCase.emptyReduceContextBuilder();
+        InternalAggregation.ReduceContext reduceContext = randomBoolean() ? builder.forFinalReduction() : builder.forPartialReduction();
         assertNull(InternalAggregations.reduce(aggs, reduceContext));
     }
 
@@ -61,7 +69,7 @@ public class InternalAggregationsTests extends ESTestCase {
         topLevelPipelineAggs.add((SiblingPipelineAggregator)maxBucketPipelineAggregationBuilder.create());
         List<InternalAggregations> aggs = Collections.singletonList(new InternalAggregations(Collections.singletonList(terms),
             topLevelPipelineAggs));
-        InternalAggregation.ReduceContext reduceContext = new InternalAggregation.ReduceContext(null, null, false);
+        InternalAggregation.ReduceContext reduceContext = InternalAggregationTestCase.emptyReduceContextBuilder().forPartialReduction();
         InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(aggs, reduceContext);
         assertEquals(1, reducedAggs.getTopLevelPipelineAggregators().size());
         assertEquals(1, reducedAggs.aggregations.size());
@@ -73,17 +81,10 @@ public class InternalAggregationsTests extends ESTestCase {
 
         MaxBucketPipelineAggregationBuilder maxBucketPipelineAggregationBuilder = new MaxBucketPipelineAggregationBuilder("test", "test");
         SiblingPipelineAggregator siblingPipelineAggregator = (SiblingPipelineAggregator) maxBucketPipelineAggregationBuilder.create();
-        InternalAggregation.ReduceContext reduceContext = new InternalAggregation.ReduceContext(null, null, true);
-        final InternalAggregations reducedAggs;
-        if (randomBoolean()) {
-            InternalAggregations aggs = new InternalAggregations(Collections.singletonList(terms),
-                Collections.singletonList(siblingPipelineAggregator));
-            reducedAggs = InternalAggregations.topLevelReduce(Collections.singletonList(aggs), reduceContext);
-        } else {
-            InternalAggregations aggs = new InternalAggregations(Collections.singletonList(terms),
-                Collections.singletonList(siblingPipelineAggregator));
-            reducedAggs = InternalAggregations.topLevelReduce(Collections.singletonList(aggs), reduceContext);
-        }
+        InternalAggregation.ReduceContext reduceContext = InternalAggregation.ReduceContext.forFinalReduction(
+                BigArrays.NON_RECYCLING_INSTANCE, null, b -> {}, new PipelineTree(emptyMap(), singletonList(siblingPipelineAggregator)));
+        InternalAggregations aggs = new InternalAggregations(Collections.singletonList(terms), emptyList());
+        InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(Collections.singletonList(aggs), reduceContext);
         assertEquals(0, reducedAggs.getTopLevelPipelineAggregators().size());
         assertEquals(2, reducedAggs.aggregations.size());
     }

+ 2 - 7
server/src/test/java/org/elasticsearch/search/aggregations/bucket/composite/InternalCompositeTests.java

@@ -22,7 +22,6 @@ package org.elasticsearch.search.aggregations.bucket.composite;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.time.DateFormatter;
-import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.index.mapper.DateFieldMapper;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.aggregations.InternalAggregation;
@@ -241,7 +240,7 @@ public class InternalCompositeTests extends InternalMultiBucketAggregationTestCa
         for (int i = 0; i < numSame; i++) {
             toReduce.add(result);
         }
-        InternalComposite finalReduce = (InternalComposite) result.reduce(toReduce, reduceContext());
+        InternalComposite finalReduce = (InternalComposite) result.reduce(toReduce, emptyReduceContextBuilder().forFinalReduction());
         assertThat(finalReduce.getBuckets().size(), equalTo(result.getBuckets().size()));
         Iterator<InternalComposite.InternalBucket> expectedIt = result.getBuckets().iterator();
         for (InternalComposite.InternalBucket bucket : finalReduce.getBuckets()) {
@@ -261,7 +260,7 @@ public class InternalCompositeTests extends InternalMultiBucketAggregationTestCa
                 rawFormats, emptyList(), null, reverseMuls, true, emptyList(), emptyMap());
         List<InternalAggregation> toReduce = Arrays.asList(unmapped, mapped);
         Collections.shuffle(toReduce, random());
-        InternalComposite finalReduce = (InternalComposite) unmapped.reduce(toReduce, reduceContext());
+        InternalComposite finalReduce = (InternalComposite) unmapped.reduce(toReduce, emptyReduceContextBuilder().forFinalReduction());
         assertThat(finalReduce.getBuckets().size(), equalTo(mapped.getBuckets().size()));
         if (false == mapped.getBuckets().isEmpty()) {
             assertThat(finalReduce.getFormats(), equalTo(mapped.getFormats()));
@@ -407,8 +406,4 @@ public class InternalCompositeTests extends InternalMultiBucketAggregationTestCa
             values
         );
     }
-
-    private InternalAggregation.ReduceContext reduceContext() {
-        return new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, true);
-    }
 }

+ 3 - 2
server/src/test/java/org/elasticsearch/search/aggregations/bucket/histogram/InternalHistogramTests.java

@@ -23,10 +23,10 @@ import org.apache.lucene.util.TestUtil;
 import org.elasticsearch.common.io.stream.Writeable.Reader;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.aggregations.BucketOrder;
-import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.aggregations.ParsedMultiBucketAggregation;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.test.InternalAggregationTestCase;
 import org.elasticsearch.test.InternalMultiBucketAggregationTestCase;
 
 import java.util.ArrayList;
@@ -109,7 +109,8 @@ public class InternalHistogramTests extends InternalMultiBucketAggregationTestCa
         newBuckets.add(new InternalHistogram.Bucket(Double.NaN, b.docCount, keyed, b.format, b.aggregations));
         
         InternalHistogram newHistogram = histogram.create(newBuckets);
-        newHistogram.reduce(Arrays.asList(newHistogram, histogram2), new InternalAggregation.ReduceContext(null, null, false));
+        newHistogram.reduce(Arrays.asList(newHistogram, histogram2),
+                InternalAggregationTestCase.emptyReduceContextBuilder().forPartialReduction());
     }
 
     @Override

+ 2 - 1
server/src/test/java/org/elasticsearch/search/aggregations/bucket/significant/SignificanceHeuristicTests.java

@@ -48,6 +48,7 @@ import org.elasticsearch.search.aggregations.bucket.significant.heuristics.Perce
 import org.elasticsearch.search.aggregations.bucket.significant.heuristics.SignificanceHeuristic;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.InternalAggregationTestCase;
 import org.elasticsearch.test.TestSearchContext;
 
 import java.io.ByteArrayInputStream;
@@ -147,7 +148,7 @@ public class SignificanceHeuristicTests extends ESTestCase {
 
     public void testReduce() {
         List<InternalAggregation> aggs = createInternalAggregations();
-        InternalAggregation.ReduceContext context = new InternalAggregation.ReduceContext(null, null, true);
+        InternalAggregation.ReduceContext context = InternalAggregationTestCase.emptyReduceContextBuilder().forFinalReduction();
         SignificantTerms reducedAgg = (SignificantTerms) aggs.get(0).reduce(aggs, context);
         assertThat(reducedAgg.getBuckets().size(), equalTo(2));
         assertThat(reducedAgg.getBuckets().get(0).getSubsetDf(), equalTo(8L));

+ 4 - 3
server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorTests.java

@@ -79,6 +79,7 @@ import org.elasticsearch.search.aggregations.bucket.nested.NestedAggregationBuil
 import org.elasticsearch.search.aggregations.metrics.InternalTopHits;
 import org.elasticsearch.search.aggregations.metrics.TopHitsAggregationBuilder;
 import org.elasticsearch.search.aggregations.pipeline.BucketScriptPipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.search.aggregations.support.AggregationInspectionHelper;
 import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
 import org.elasticsearch.search.aggregations.support.ValueType;
@@ -1084,9 +1085,9 @@ public class TermsAggregatorTests extends AggregatorTestCase {
                 }
                 dir.close();
             }
-            InternalAggregation.ReduceContext ctx =
-                new InternalAggregation.ReduceContext(new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY),
-                    new NoneCircuitBreakerService()), null, true);
+            InternalAggregation.ReduceContext ctx = InternalAggregation.ReduceContext.forFinalReduction(
+                    new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()),
+                    null, b -> {}, PipelineTree.EMPTY);
             for (InternalAggregation internalAgg : aggs) {
                 InternalAggregation mergedAggs = internalAgg.reduce(aggs, ctx);
                 assertTrue(mergedAggs instanceof DoubleTerms);

+ 1 - 1
server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java

@@ -1358,7 +1358,7 @@ public class SnapshotResiliencyTests extends ESTestCase {
                     bigArrays, new FetchPhase(Collections.emptyList()), responseCollectorService, new NoneCircuitBreakerService());
                 actions.put(SearchAction.INSTANCE,
                     new TransportSearchAction(threadPool, transportService, searchService,
-                        searchTransportService, new SearchPhaseController(searchService::createReduceContext), clusterService,
+                        searchTransportService, new SearchPhaseController(searchService::aggReduceContextBuilder), clusterService,
                         actionFilters, indexNameExpressionResolver));
                 actions.put(RestoreSnapshotAction.INSTANCE,
                     new TransportRestoreSnapshotAction(transportService, clusterService, threadPool, restoreService, actionFilters,

+ 10 - 14
test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java

@@ -94,6 +94,7 @@ import org.elasticsearch.mock.orig.Mockito;
 import org.elasticsearch.script.ScriptService;
 import org.elasticsearch.search.aggregations.MultiBucketConsumerService.MultiBucketConsumer;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
 import org.elasticsearch.search.aggregations.support.ValuesSourceType;
 import org.elasticsearch.search.fetch.FetchPhase;
@@ -454,7 +455,8 @@ public abstract class AggregatorTestCase extends ESTestCase {
             }
         }
 
-        List<InternalAggregation> aggs = new ArrayList<> ();
+        PipelineTree pipelines = builder.buildPipelineTree();
+        List<InternalAggregation> aggs = new ArrayList<>();
         Query rewritten = searcher.rewrite(query);
         Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1f);
         MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket,
@@ -481,33 +483,27 @@ public abstract class AggregatorTestCase extends ESTestCase {
                 Collections.shuffle(aggs, random());
                 int r = randomIntBetween(1, toReduceSize);
                 List<InternalAggregation> toReduce = aggs.subList(0, r);
-                MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumer(maxBucket,
-                    new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST));
-                InternalAggregation.ReduceContext context =
-                    new InternalAggregation.ReduceContext(root.context().bigArrays(), getMockScriptService(),
-                        reduceBucketConsumer, false);
+                InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forPartialReduction(
+                        root.context().bigArrays(), getMockScriptService());
                 A reduced = (A) aggs.get(0).reduce(toReduce, context);
-                doAssertReducedMultiBucketConsumer(reduced, reduceBucketConsumer);
                 aggs = new ArrayList<>(aggs.subList(r, toReduceSize));
                 aggs.add(reduced);
             }
             // now do the final reduce
             MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumer(maxBucket,
                 new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST));
-            InternalAggregation.ReduceContext context =
-                new InternalAggregation.ReduceContext(root.context().bigArrays(), getMockScriptService(), reduceBucketConsumer, true);
+            InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forFinalReduction(
+                    root.context().bigArrays(), getMockScriptService(), reduceBucketConsumer, pipelines);
 
             @SuppressWarnings("unchecked")
             A internalAgg = (A) aggs.get(0).reduce(aggs, context);
 
             // materialize any parent pipelines
-            internalAgg = (A) internalAgg.reducePipelines(internalAgg, context);
+            internalAgg = (A) internalAgg.reducePipelines(internalAgg, context, pipelines);
 
             // materialize any sibling pipelines at top level
-            if (internalAgg.pipelineAggregators().size() > 0) {
-                for (PipelineAggregator pipelineAggregator : internalAgg.pipelineAggregators()) {
-                    internalAgg = (A) pipelineAggregator.reduce(internalAgg, context);
-                }
+            for (PipelineAggregator pipelineAggregator : pipelines.aggregators()) {
+                internalAgg = (A) pipelineAggregator.reduce(internalAgg, context);
             }
             doAssertReducedMultiBucketConsumer(internalAgg, reduceBucketConsumer);
             return internalAgg;

+ 24 - 7
test/framework/src/main/java/org/elasticsearch/test/InternalAggregationTestCase.java

@@ -25,6 +25,7 @@ import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.MockBigArrays;
 import org.elasticsearch.common.util.MockPageCacheRecycler;
 import org.elasticsearch.common.xcontent.ContextParser;
@@ -42,6 +43,7 @@ import org.elasticsearch.search.aggregations.Aggregation;
 import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.MultiBucketConsumerService.MultiBucketConsumer;
 import org.elasticsearch.search.aggregations.ParsedAggregation;
+import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext;
 import org.elasticsearch.search.aggregations.bucket.adjacency.AdjacencyMatrixAggregationBuilder;
 import org.elasticsearch.search.aggregations.bucket.adjacency.ParsedAdjacencyMatrix;
 import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregationBuilder;
@@ -137,6 +139,7 @@ import org.elasticsearch.search.aggregations.pipeline.ParsedStatsBucket;
 import org.elasticsearch.search.aggregations.pipeline.PercentilesBucketPipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
 import org.elasticsearch.search.aggregations.pipeline.StatsBucketPipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -159,6 +162,24 @@ import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
 
 public abstract class InternalAggregationTestCase<T extends InternalAggregation> extends AbstractWireSerializingTestCase<T> {
+    /**
+     * Builds an {@link InternalAggregation.ReduceContextBuilder} that is valid but empty.
+     */
+    public static InternalAggregation.ReduceContextBuilder emptyReduceContextBuilder() {
+        return new InternalAggregation.ReduceContextBuilder() {
+            @Override
+            public InternalAggregation.ReduceContext forPartialReduction() {
+                return InternalAggregation.ReduceContext.forPartialReduction(BigArrays.NON_RECYCLING_INSTANCE, null);
+            }
+
+            @Override
+            public ReduceContext forFinalReduction() {
+                return InternalAggregation.ReduceContext.forFinalReduction(
+                        BigArrays.NON_RECYCLING_INSTANCE, null, b -> {}, PipelineTree.EMPTY);
+            }
+        };
+    }
+
     public static final int DEFAULT_MAX_BUCKETS = 100000;
     protected static final double TOLERANCE = 1e-10;
 
@@ -270,10 +291,7 @@ public abstract class InternalAggregationTestCase<T extends InternalAggregation>
             Collections.shuffle(toReduce, random());
             int r = randomIntBetween(1, toReduceSize);
             List<InternalAggregation> internalAggregations = toReduce.subList(0, r);
-            MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(DEFAULT_MAX_BUCKETS,
-                new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST));
-            InternalAggregation.ReduceContext context =
-                new InternalAggregation.ReduceContext(bigArrays, mockScriptService, bucketConsumer,false);
+            InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forPartialReduction(bigArrays, mockScriptService);
             @SuppressWarnings("unchecked")
             T reduced = (T) inputs.get(0).reduce(internalAggregations, context);
             int initialBucketCount = 0;
@@ -283,14 +301,13 @@ public abstract class InternalAggregationTestCase<T extends InternalAggregation>
             int reducedBucketCount = countInnerBucket(reduced);
             //check that non final reduction never adds buckets
             assertThat(reducedBucketCount, lessThanOrEqualTo(initialBucketCount));
-            assertMultiBucketConsumer(reducedBucketCount, bucketConsumer);
             toReduce = new ArrayList<>(toReduce.subList(r, toReduceSize));
             toReduce.add(reduced);
         }
         MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(DEFAULT_MAX_BUCKETS,
             new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST));
-        InternalAggregation.ReduceContext context =
-            new InternalAggregation.ReduceContext(bigArrays, mockScriptService, bucketConsumer, true);
+        InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forFinalReduction(
+                bigArrays, mockScriptService, bucketConsumer, PipelineTree.EMPTY);
         @SuppressWarnings("unchecked")
         T reduced = (T) inputs.get(0).reduce(toReduce, context);
         assertMultiBucketConsumer(reduced, bucketConsumer);

+ 7 - 7
x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java

@@ -21,7 +21,7 @@ import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchShardTarget;
-import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext;
+import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.internal.InternalSearchResponse;
 import org.elasticsearch.tasks.TaskId;
@@ -45,7 +45,7 @@ final class AsyncSearchTask extends SearchTask {
     private final AsyncSearchId searchId;
     private final Client client;
     private final ThreadPool threadPool;
-    private final Supplier<ReduceContext> reduceContextSupplier;
+    private final Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier;
     private final Listener progressListener;
 
     private final Map<String, String> originHeaders;
@@ -72,7 +72,7 @@ final class AsyncSearchTask extends SearchTask {
      * @param taskHeaders The filtered request headers for the task.
      * @param searchId The {@link AsyncSearchId} of the task.
      * @param threadPool The threadPool to schedule runnable.
-     * @param reduceContextSupplier A supplier to create final reduce contexts.
+     * @param aggReduceContextSupplier A supplier to create final reduce contexts.
      */
     AsyncSearchTask(long id,
                     String type,
@@ -84,14 +84,14 @@ final class AsyncSearchTask extends SearchTask {
                     AsyncSearchId searchId,
                     Client client,
                     ThreadPool threadPool,
-                    Supplier<ReduceContext> reduceContextSupplier) {
+                    Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier) {
         super(id, type, action, "async_search", parentTaskId, taskHeaders);
         this.expirationTimeMillis = getStartTime() + keepAlive.getMillis();
         this.originHeaders = originHeaders;
         this.searchId = searchId;
         this.client = client;
         this.threadPool = threadPool;
-        this.reduceContextSupplier = reduceContextSupplier;
+        this.aggReduceContextSupplier = aggReduceContextSupplier;
         this.progressListener = new Listener();
         this.searchResponse = new AtomicReference<>();
         setProgressListener(progressListener);
@@ -327,7 +327,7 @@ final class AsyncSearchTask extends SearchTask {
             // best effort to cancel expired tasks
             checkExpiration();
             searchResponse.compareAndSet(null,
-                new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, reduceContextSupplier));
+                new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, aggReduceContextSupplier));
             executeInitListeners();
         }
 
@@ -360,7 +360,7 @@ final class AsyncSearchTask extends SearchTask {
             if (searchResponse.get() == null) {
                 // if the failure occurred before calling onListShards
                 searchResponse.compareAndSet(null,
-                    new MutableSearchResponse(-1, -1, null, reduceContextSupplier));
+                    new MutableSearchResponse(-1, -1, null, aggReduceContextSupplier));
             }
             searchResponse.get().updateWithFailure(exc);
             executeInitListeners();

+ 7 - 7
x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java

@@ -14,12 +14,11 @@ import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.search.SearchHits;
-import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext;
+import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.internal.InternalSearchResponse;
 import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse;
 
-
 import java.util.ArrayList;
 import java.util.List;
 import java.util.function.Supplier;
@@ -39,7 +38,7 @@ class MutableSearchResponse {
     private final int skippedShards;
     private final Clusters clusters;
     private final AtomicArray<ShardSearchFailure> shardFailures;
-    private final Supplier<ReduceContext> reduceContextSupplier;
+    private final Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier;
 
     private int version;
     private boolean isPartial;
@@ -56,13 +55,14 @@ class MutableSearchResponse {
      * @param totalShards The number of shards that participate in the request, or -1 to indicate a failure.
      * @param skippedShards The number of skipped shards, or -1 to indicate a failure.
      * @param clusters The remote clusters statistics.
-     * @param reduceContextSupplier A supplier to run final reduce on partial aggregations.
+     * @param aggReduceContextSupplier A supplier to run final reduce on partial aggregations.
      */
-    MutableSearchResponse(int totalShards, int skippedShards, Clusters clusters, Supplier<ReduceContext> reduceContextSupplier) {
+    MutableSearchResponse(int totalShards, int skippedShards, Clusters clusters,
+            Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier) {
         this.totalShards = totalShards;
         this.skippedShards = skippedShards;
         this.clusters = clusters;
-        this.reduceContextSupplier = reduceContextSupplier;
+        this.aggReduceContextSupplier = aggReduceContextSupplier;
         this.version = 0;
         this.shardFailures = totalShards == -1 ? null : new AtomicArray<>(totalShards-skippedShards);
         this.isPartial = true;
@@ -136,7 +136,7 @@ class MutableSearchResponse {
         if (totalShards != -1) {
             if (sections.aggregations() != null && isFinalReduce == false) {
                 InternalAggregations oldAggs = (InternalAggregations) sections.aggregations();
-                InternalAggregations newAggs = topLevelReduce(singletonList(oldAggs), reduceContextSupplier.get());
+                InternalAggregations newAggs = topLevelReduce(singletonList(oldAggs), aggReduceContextSupplier.get());
                 sections = new InternalSearchResponse(sections.hits(), newAggs, sections.suggest(),
                     null, sections.timedOut(), sections.terminatedEarly(), sections.getNumReducePhases());
                 isFinalReduce = true;

+ 7 - 4
x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java

@@ -25,7 +25,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.index.engine.DocumentMissingException;
 import org.elasticsearch.search.SearchService;
-import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext;
+import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskCancelledException;
@@ -36,13 +36,14 @@ import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchAction;
 import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchRequest;
 
 import java.util.Map;
+import java.util.function.Function;
 import java.util.function.Supplier;
 
 public class TransportSubmitAsyncSearchAction extends HandledTransportAction<SubmitAsyncSearchRequest, AsyncSearchResponse> {
     private static final Logger logger = LogManager.getLogger(TransportSubmitAsyncSearchAction.class);
 
     private final NodeClient nodeClient;
-    private final Supplier<ReduceContext> reduceContextSupplier;
+    private final Function<SearchRequest, InternalAggregation.ReduceContext> requestToAggReduceContextBuilder;
     private final TransportSearchAction searchAction;
     private final AsyncSearchIndexService store;
 
@@ -57,7 +58,7 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
                                             TransportSearchAction searchAction) {
         super(SubmitAsyncSearchAction.NAME, transportService, actionFilters, SubmitAsyncSearchRequest::new);
         this.nodeClient = nodeClient;
-        this.reduceContextSupplier = () -> searchService.createReduceContext(true);
+        this.requestToAggReduceContextBuilder = request -> searchService.aggReduceContextBuilder(request).forFinalReduction();
         this.searchAction = searchAction;
         this.store = new AsyncSearchIndexService(clusterService, transportService.getThreadPool().getThreadContext(), client, registry);
     }
@@ -135,8 +136,10 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
             @Override
             public AsyncSearchTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> taskHeaders) {
                 AsyncSearchId searchId = new AsyncSearchId(docID, new TaskId(nodeClient.getLocalNodeId(), id));
+                Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier =
+                        () -> requestToAggReduceContextBuilder.apply(request.getSearchRequest());
                 return new AsyncSearchTask(id, type, action, parentTaskId, keepAlive, originHeaders, taskHeaders, searchId,
-                    store.getClient(), nodeClient.threadPool(), reduceContextSupplier);
+                    store.getClient(), nodeClient.threadPool(), aggReduceContextSupplier);
             }
         };
         searchRequest.setParentTask(new TaskId(nodeClient.getLocalNodeId(), parentTaskId));

+ 7 - 4
x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/RollupResponseTranslator.java

@@ -34,6 +34,7 @@ import org.elasticsearch.search.aggregations.metrics.InternalMin;
 import org.elasticsearch.search.aggregations.metrics.InternalNumericMetricsAggregation.SingleValue;
 import org.elasticsearch.search.aggregations.metrics.InternalSum;
 import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.search.internal.InternalSearchResponse;
 import org.elasticsearch.xpack.core.rollup.RollupField;
 
@@ -272,6 +273,8 @@ public class RollupResponseTranslator {
         // which means we can use aggregation's reduce method to combine, just as if
         // it was a result from another shard
         InternalAggregations currentTree = new InternalAggregations(Collections.emptyList());
+        InternalAggregation.ReduceContext finalReduceContext = InternalAggregation.ReduceContext.forFinalReduction(
+                reduceContext.bigArrays(), reduceContext.scriptService(), b -> {}, PipelineTree.EMPTY);
         for (SearchResponse rolledResponse : rolledResponses) {
             List<InternalAggregation> unrolledAggs = new ArrayList<>(rolledResponse.getAggregations().asList().size());
             for (Aggregation agg : rolledResponse.getAggregations()) {
@@ -289,14 +292,14 @@ public class RollupResponseTranslator {
             // Iteratively merge in each new set of unrolled aggs, so that we can identify/fix overlapping doc_counts
             // in the next round of unrolling
             InternalAggregations finalUnrolledAggs = new InternalAggregations(unrolledAggs);
-            currentTree = InternalAggregations.reduce(Arrays.asList(currentTree, finalUnrolledAggs),
-                    new InternalAggregation.ReduceContext(reduceContext.bigArrays(), reduceContext.scriptService(), true));
+            currentTree = InternalAggregations.reduce(Arrays.asList(currentTree, finalUnrolledAggs), finalReduceContext);
         }
 
         // Add in the live aggregations if they exist
         if (liveAggs.asList().size() != 0) {
-            currentTree = InternalAggregations.reduce(Arrays.asList(currentTree, liveAggs),
-                    new InternalAggregation.ReduceContext(reduceContext.bigArrays(), reduceContext.scriptService(), true));
+            // TODO it looks like this passes the "final" reduce context more than once.
+            // Once here and once in the for above. That is bound to cause trouble.
+            currentTree = InternalAggregations.reduce(Arrays.asList(currentTree, liveAggs), finalReduceContext);
         }
 
         return mergeFinalResponse(liveResponse, rolledResponses, currentTree);

+ 1 - 2
x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/action/TransportRollupSearchAction.java

@@ -105,8 +105,7 @@ public class TransportRollupSearchAction extends TransportAction<SearchRequest,
         MultiSearchRequest msearch = createMSearchRequest(request, registry, rollupSearchContext);
 
         client.multiSearch(msearch, ActionListener.wrap(msearchResponse -> {
-            InternalAggregation.ReduceContext context
-                    = new InternalAggregation.ReduceContext(bigArrays, scriptService, false);
+            InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forPartialReduction(bigArrays, scriptService);
             listener.onResponse(processResponses(rollupSearchContext, msearchResponse, context));
         }, listener::onFailure));
     }

+ 19 - 14
x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/RollupResponseTranslationTests.java

@@ -66,6 +66,7 @@ import org.elasticsearch.search.aggregations.metrics.InternalSum;
 import org.elasticsearch.search.aggregations.metrics.MaxAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.elasticsearch.search.aggregations.support.ValueType;
 import org.elasticsearch.search.internal.InternalSearchResponse;
 import org.elasticsearch.xpack.core.rollup.RollupField;
@@ -96,12 +97,12 @@ public class RollupResponseTranslationTests extends AggregatorTestCase {
 
         Exception e = expectThrows(RuntimeException.class,
                 () -> RollupResponseTranslator.combineResponses(failure,
-                        new InternalAggregation.ReduceContext(bigArrays, scriptService, true)));
+                        InternalAggregation.ReduceContext.forFinalReduction(bigArrays, scriptService, b -> {}, PipelineTree.EMPTY)));
         assertThat(e.getMessage(), equalTo("foo"));
 
         e = expectThrows(RuntimeException.class,
                 () -> RollupResponseTranslator.translateResponse(failure,
-                        new InternalAggregation.ReduceContext(bigArrays, scriptService, true)));
+                        InternalAggregation.ReduceContext.forFinalReduction(bigArrays, scriptService, b -> {}, PipelineTree.EMPTY)));
         assertThat(e.getMessage(), equalTo("foo"));
 
         e = expectThrows(RuntimeException.class,
@@ -118,7 +119,7 @@ public class RollupResponseTranslationTests extends AggregatorTestCase {
 
         Exception e = expectThrows(RuntimeException.class,
                 () -> RollupResponseTranslator.translateResponse(failure,
-                        new InternalAggregation.ReduceContext(bigArrays, scriptService, true)));
+                        InternalAggregation.ReduceContext.forFinalReduction(bigArrays, scriptService, b -> {}, PipelineTree.EMPTY)));
         assertThat(e.getMessage(), equalTo("rollup failure"));
     }
 
@@ -132,7 +133,7 @@ public class RollupResponseTranslationTests extends AggregatorTestCase {
 
         ResourceNotFoundException e = expectThrows(ResourceNotFoundException.class,
                 () -> RollupResponseTranslator.combineResponses(failure,
-                        new InternalAggregation.ReduceContext(bigArrays, scriptService, true)));
+                        InternalAggregation.ReduceContext.forFinalReduction(bigArrays, scriptService, b -> {}, PipelineTree.EMPTY)));
         assertThat(e.getMessage(), equalTo("Index [[foo]] was not found, likely because it was deleted while the request was in-flight. " +
             "Rollup does not support partial search results, please try the request again."));
     }
@@ -177,7 +178,7 @@ public class RollupResponseTranslationTests extends AggregatorTestCase {
         ScriptService scriptService = mock(ScriptService.class);
 
         ResourceNotFoundException e = expectThrows(ResourceNotFoundException.class, () -> RollupResponseTranslator.combineResponses(msearch,
-                new InternalAggregation.ReduceContext(bigArrays, scriptService, true)));
+                InternalAggregation.ReduceContext.forFinalReduction(bigArrays, scriptService, b -> {}, PipelineTree.EMPTY)));
         assertThat(e.getMessage(), equalTo("Index [[foo]] was not found, likely because it was deleted while the request was in-flight. " +
             "Rollup does not support partial search results, please try the request again."));
     }
@@ -196,7 +197,7 @@ public class RollupResponseTranslationTests extends AggregatorTestCase {
         ScriptService scriptService = mock(ScriptService.class);
 
         SearchResponse response = RollupResponseTranslator.translateResponse(msearch,
-            new InternalAggregation.ReduceContext(bigArrays, scriptService, true));
+                InternalAggregation.ReduceContext.forFinalReduction(bigArrays, scriptService, b -> {}, PipelineTree.EMPTY));
         assertNotNull(response);
         Aggregations responseAggs = response.getAggregations();
         assertThat(responseAggs.asList().size(), equalTo(0));
@@ -213,7 +214,7 @@ public class RollupResponseTranslationTests extends AggregatorTestCase {
         ScriptService scriptService = mock(ScriptService.class);
 
         ResourceNotFoundException e = expectThrows(ResourceNotFoundException.class, () -> RollupResponseTranslator.combineResponses(msearch,
-            new InternalAggregation.ReduceContext(bigArrays, scriptService, true)));
+                InternalAggregation.ReduceContext.forFinalReduction(bigArrays, scriptService, b -> {}, PipelineTree.EMPTY)));
         assertThat(e.getMessage(), equalTo("Index [[foo]] was not found, likely because it was deleted while the request was in-flight. " +
             "Rollup does not support partial search results, please try the request again."));
     }
@@ -268,7 +269,8 @@ public class RollupResponseTranslationTests extends AggregatorTestCase {
 
         BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
         ScriptService scriptService = mock(ScriptService.class);
-        InternalAggregation.ReduceContext context = new InternalAggregation.ReduceContext(bigArrays, scriptService, true);
+        InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forFinalReduction(
+                bigArrays, scriptService, b -> {}, PipelineTree.EMPTY);
 
         SearchResponse finalResponse = RollupResponseTranslator.translateResponse(new MultiSearchResponse.Item[]{item}, context);
         assertNotNull(finalResponse);
@@ -282,7 +284,8 @@ public class RollupResponseTranslationTests extends AggregatorTestCase {
         MultiSearchResponse.Item missing = new MultiSearchResponse.Item(null, new IndexNotFoundException("foo"));
         BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
         ScriptService scriptService = mock(ScriptService.class);
-        InternalAggregation.ReduceContext context = new InternalAggregation.ReduceContext(bigArrays, scriptService, true);
+        InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forFinalReduction(
+                bigArrays, scriptService, b -> {}, PipelineTree.EMPTY);
 
         ResourceNotFoundException e = expectThrows(ResourceNotFoundException.class,
                 () -> RollupResponseTranslator.translateResponse(new MultiSearchResponse.Item[]{missing}, context));
@@ -316,7 +319,7 @@ public class RollupResponseTranslationTests extends AggregatorTestCase {
 
         Exception e = expectThrows(RuntimeException.class,
                 () -> RollupResponseTranslator.combineResponses(msearch,
-                        new InternalAggregation.ReduceContext(bigArrays, scriptService, true)));
+                        InternalAggregation.ReduceContext.forFinalReduction(bigArrays, scriptService, b -> {}, PipelineTree.EMPTY)));
         assertThat(e.getMessage(), containsString("Expected [bizzbuzz] to be a FilterAggregation"));
     }
 
@@ -345,7 +348,7 @@ public class RollupResponseTranslationTests extends AggregatorTestCase {
 
         Exception e = expectThrows(RuntimeException.class,
                 () -> RollupResponseTranslator.combineResponses(msearch,
-                        new InternalAggregation.ReduceContext(bigArrays, scriptService, true)));
+                        InternalAggregation.ReduceContext.forFinalReduction(bigArrays, scriptService, b -> {}, PipelineTree.EMPTY)));
         assertThat(e.getMessage(),
                 equalTo("Expected [filter_foo] to be a FilterAggregation, but was [InternalMax]"));
     }
@@ -399,7 +402,7 @@ public class RollupResponseTranslationTests extends AggregatorTestCase {
 
 
         SearchResponse response = RollupResponseTranslator.combineResponses(msearch,
-                new InternalAggregation.ReduceContext(bigArrays, scriptService, true));
+                InternalAggregation.ReduceContext.forFinalReduction(bigArrays, scriptService, b -> {}, PipelineTree.EMPTY));
         assertNotNull(response);
         Aggregations responseAggs = response.getAggregations();
         assertNotNull(responseAggs);
@@ -507,7 +510,8 @@ public class RollupResponseTranslationTests extends AggregatorTestCase {
 
         BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
         ScriptService scriptService = mock(ScriptService.class);
-        InternalAggregation.ReduceContext reduceContext = new InternalAggregation.ReduceContext(bigArrays, scriptService, true);
+        InternalAggregation.ReduceContext reduceContext = InternalAggregation.ReduceContext.forFinalReduction(
+                bigArrays, scriptService, b -> {}, PipelineTree.EMPTY);
         ClassCastException e = expectThrows(ClassCastException.class,
                 () -> RollupResponseTranslator.combineResponses(msearch, reduceContext));
         assertThat(e.getMessage(),
@@ -608,7 +612,8 @@ public class RollupResponseTranslationTests extends AggregatorTestCase {
         // Reduce the InternalDateHistogram response so we can fill buckets
         BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
         ScriptService scriptService = mock(ScriptService.class);
-        InternalAggregation.ReduceContext context = new InternalAggregation.ReduceContext(bigArrays, scriptService, true);
+        InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forFinalReduction(
+                bigArrays, scriptService, b -> {}, PipelineTree.EMPTY);
 
         InternalAggregation reduced = ((InternalDateHistogram)unrolled).reduce(Collections.singletonList(unrolled), context);
         assertThat(reduced.toString(), equalTo("{\"histo\":{\"buckets\":[{\"key_as_string\":\"1970-01-01T00:00:00.100Z\",\"key\":100," +