浏览代码

Release aggregations earlier during reduce (#124520)

Release each hit's aggregations before moving on to the next hit and unlink it from the shard result even earlier.
Also, do the aggregation-reduce earlier in the reduce steps to reduce average heap use over time. 
To that effect, do not do the reduction in the search phase controller. This has the added benefit of removing any need for a fake aggs-reduce-context in scroll.
Armin Braun 7 月之前
父节点
当前提交
70486b45e6

+ 66 - 31
server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java

@@ -15,9 +15,12 @@ import org.apache.lucene.search.TopDocs;
 import org.elasticsearch.action.search.SearchPhaseController.TopDocsStats;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.breaker.CircuitBreakingException;
+import org.elasticsearch.common.collect.Iterators;
 import org.elasticsearch.common.io.stream.DelayableWriteable;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.SearchShardTarget;
@@ -31,6 +34,7 @@ import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.Executor;
 import java.util.concurrent.atomic.AtomicReference;
@@ -174,14 +178,10 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
         this.mergeResult = null;
         final int resultSize = buffer.size() + (mergeResult == null ? 0 : 1);
         final List<TopDocs> topDocsList = hasTopDocs ? new ArrayList<>(resultSize) : null;
-        final List<DelayableWriteable<InternalAggregations>> aggsList = hasAggs ? new ArrayList<>(resultSize) : null;
         if (mergeResult != null) {
             if (topDocsList != null) {
                 topDocsList.add(mergeResult.reducedTopDocs);
             }
-            if (aggsList != null) {
-                aggsList.add(DelayableWriteable.referencing(mergeResult.reducedAggs));
-            }
         }
         for (QuerySearchResult result : buffer) {
             topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
@@ -190,34 +190,39 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
                 setShardIndex(topDocs.topDocs, result.getShardIndex());
                 topDocsList.add(topDocs.topDocs);
             }
-            if (aggsList != null) {
-                aggsList.add(result.getAggs());
-            }
         }
         SearchPhaseController.ReducedQueryPhase reducePhase;
         long breakerSize = circuitBreakerBytes;
+        final InternalAggregations aggs;
         try {
-            if (aggsList != null) {
+            if (hasAggs) {
                 // Add an estimate of the final reduce size
                 breakerSize = addEstimateAndMaybeBreak(estimateRamBytesUsedForReduce(breakerSize));
+                aggs = aggregate(
+                    buffer.iterator(),
+                    mergeResult,
+                    resultSize,
+                    performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction()
+                );
+            } else {
+                aggs = null;
             }
             reducePhase = SearchPhaseController.reducedQueryPhase(
                 results.asList(),
-                aggsList,
+                aggs,
                 topDocsList == null ? Collections.emptyList() : topDocsList,
                 topDocsStats,
                 numReducePhases,
                 false,
-                aggReduceContextBuilder,
-                queryPhaseRankCoordinatorContext,
-                performFinalReduce
+                queryPhaseRankCoordinatorContext
             );
+            buffer = null;
         } finally {
             releaseAggs(buffer);
         }
         if (hasAggs
             // reduced aggregations can be null if all shards failed
-            && reducePhase.aggregations() != null) {
+            && aggs != null) {
 
             // Update the circuit breaker to replace the estimation with the serialized size of the newly reduced result
             long finalSize = DelayableWriteable.getSerializedSize(reducePhase.aggregations()) - breakerSize;
@@ -249,17 +254,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
         toConsume.sort(RESULT_COMPARATOR);
 
         final TopDocs newTopDocs;
-        final InternalAggregations newAggs;
-        final List<DelayableWriteable<InternalAggregations>> aggsList;
         final int resultSetSize = toConsume.size() + (lastMerge != null ? 1 : 0);
-        if (hasAggs) {
-            aggsList = new ArrayList<>(resultSetSize);
-            if (lastMerge != null) {
-                aggsList.add(DelayableWriteable.referencing(lastMerge.reducedAggs));
-            }
-        } else {
-            aggsList = null;
-        }
         List<TopDocs> topDocsList;
         if (hasTopDocs) {
             topDocsList = new ArrayList<>(resultSetSize);
@@ -269,14 +264,12 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
         } else {
             topDocsList = null;
         }
+        final InternalAggregations newAggs;
         try {
             for (QuerySearchResult result : toConsume) {
                 topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
                 SearchShardTarget target = result.getSearchShardTarget();
                 processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
-                if (aggsList != null) {
-                    aggsList.add(result.getAggs());
-                }
                 if (topDocsList != null) {
                     TopDocsAndMaxScore topDocs = result.consumeTopDocs();
                     setShardIndex(topDocs.topDocs, result.getShardIndex());
@@ -285,9 +278,10 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
             }
             // we have to merge here in the same way we collect on a shard
             newTopDocs = topDocsList == null ? null : mergeTopDocs(topDocsList, topNSize, 0);
-            newAggs = aggsList == null
-                ? null
-                : InternalAggregations.topLevelReduceDelayable(aggsList, aggReduceContextBuilder.forPartialReduction());
+            newAggs = hasAggs
+                ? aggregate(toConsume.iterator(), lastMerge, resultSetSize, aggReduceContextBuilder.forPartialReduction())
+                : null;
+            toConsume = null;
         } finally {
             releaseAggs(toConsume);
         }
@@ -302,6 +296,45 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
         return new MergeResult(processedShards, newTopDocs, newAggs, newAggs != null ? DelayableWriteable.getSerializedSize(newAggs) : 0);
     }
 
+    private static InternalAggregations aggregate(
+        Iterator<QuerySearchResult> toConsume,
+        MergeResult lastMerge,
+        int resultSetSize,
+        AggregationReduceContext reduceContext
+    ) {
+        interface ReleasableIterator extends Iterator<InternalAggregations>, Releasable {}
+        try (var aggsIter = new ReleasableIterator() {
+
+            private Releasable toRelease;
+
+            @Override
+            public void close() {
+                Releasables.close(toRelease);
+            }
+
+            @Override
+            public boolean hasNext() {
+                return toConsume.hasNext();
+            }
+
+            @Override
+            public InternalAggregations next() {
+                var res = toConsume.next().consumeAggs();
+                Releasables.close(toRelease);
+                toRelease = res;
+                return res.expand();
+            }
+        }) {
+            return InternalAggregations.topLevelReduce(
+                lastMerge == null ? aggsIter : Iterators.concat(Iterators.single(lastMerge.reducedAggs), aggsIter),
+                resultSetSize,
+                reduceContext
+            );
+        } finally {
+            toConsume.forEachRemaining(QuerySearchResult::releaseAggs);
+        }
+    }
+
     public int getNumReducePhases() {
         return numReducePhases;
     }
@@ -517,8 +550,10 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
     }
 
     private static void releaseAggs(List<QuerySearchResult> toConsume) {
-        for (QuerySearchResult result : toConsume) {
-            result.releaseAggs();
+        if (toConsume != null) {
+            for (QuerySearchResult result : toConsume) {
+                result.releaseAggs();
+            }
         }
     }
 

+ 4 - 12
server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

@@ -20,7 +20,6 @@ import org.apache.lucene.search.TopFieldDocs;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.TotalHits.Relation;
 import org.elasticsearch.common.breaker.CircuitBreaker;
-import org.elasticsearch.common.io.stream.DelayableWriteable;
 import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.common.util.Maps;
@@ -401,7 +400,7 @@ public final class SearchPhaseController {
     /**
      * Reduces the given query results and consumes all aggregations and profile results.
      * @param queryResults a list of non-null query shard results
-     * @param bufferedAggs a list of pre-collected aggregations.
+     * @param reducedAggs already reduced aggregations
      * @param bufferedTopDocs a list of pre-collected top docs.
      * @param numReducePhases the number of non-final reduce phases applied to the query results.
      * @see QuerySearchResult#getAggs()
@@ -409,14 +408,12 @@ public final class SearchPhaseController {
      */
     static ReducedQueryPhase reducedQueryPhase(
         Collection<? extends SearchPhaseResult> queryResults,
-        @Nullable List<DelayableWriteable<InternalAggregations>> bufferedAggs,
+        @Nullable InternalAggregations reducedAggs,
         List<TopDocs> bufferedTopDocs,
         TopDocsStats topDocsStats,
         int numReducePhases,
         boolean isScrollRequest,
-        AggregationReduceContext.Builder aggReduceContextBuilder,
-        QueryPhaseRankCoordinatorContext queryPhaseRankCoordinatorContext,
-        boolean performFinalReduce
+        QueryPhaseRankCoordinatorContext queryPhaseRankCoordinatorContext
     ) {
         assert numReducePhases >= 0 : "num reduce phases must be >= 0 but was: " + numReducePhases;
         numReducePhases++; // increment for this phase
@@ -520,12 +517,7 @@ public final class SearchPhaseController {
             topDocsStats.timedOut,
             topDocsStats.terminatedEarly,
             reducedSuggest,
-            bufferedAggs == null
-                ? null
-                : InternalAggregations.topLevelReduceDelayable(
-                    bufferedAggs,
-                    performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction()
-                ),
+            reducedAggs,
             profileShardResults.isEmpty() ? null : new SearchProfileResultsBuilder(profileShardResults),
             sortedTopDocs,
             sortValueFormats,

+ 1 - 23
server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java

@@ -20,7 +20,6 @@ import org.elasticsearch.common.util.concurrent.CountDown;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
-import org.elasticsearch.search.aggregations.AggregationReduceContext;
 import org.elasticsearch.search.internal.InternalScrollSearchRequest;
 import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.internal.ShardSearchContextId;
@@ -313,17 +312,6 @@ abstract class SearchScrollAsyncAction<T extends SearchPhaseResult> {
      * @param queryResults a list of non-null query shard results
      */
     protected static SearchPhaseController.ReducedQueryPhase reducedScrollQueryPhase(Collection<? extends SearchPhaseResult> queryResults) {
-        AggregationReduceContext.Builder aggReduceContextBuilder = new AggregationReduceContext.Builder() {
-            @Override
-            public AggregationReduceContext forPartialReduction() {
-                throw new UnsupportedOperationException("Scroll requests don't have aggs");
-            }
-
-            @Override
-            public AggregationReduceContext forFinalReduction() {
-                throw new UnsupportedOperationException("Scroll requests don't have aggs");
-            }
-        };
         final SearchPhaseController.TopDocsStats topDocsStats = new SearchPhaseController.TopDocsStats(
             SearchContext.TRACK_TOTAL_HITS_ACCURATE
         );
@@ -339,16 +327,6 @@ abstract class SearchScrollAsyncAction<T extends SearchPhaseResult> {
                 topDocs.add(td.topDocs);
             }
         }
-        return SearchPhaseController.reducedQueryPhase(
-            queryResults,
-            null,
-            topDocs,
-            topDocsStats,
-            0,
-            true,
-            aggReduceContextBuilder,
-            null,
-            true
-        );
+        return SearchPhaseController.reducedQueryPhase(queryResults, null, topDocs, topDocsStats, 0, true, null);
     }
 }

+ 31 - 33
server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java

@@ -8,7 +8,6 @@
  */
 package org.elasticsearch.search.aggregations;
 
-import org.elasticsearch.common.io.stream.DelayableWriteable;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -23,7 +22,6 @@ import org.elasticsearch.xcontent.ToXContentFragment;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
-import java.util.AbstractList;
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
@@ -180,44 +178,22 @@ public final class InternalAggregations implements Iterable<InternalAggregation>
     }
 
     /**
-     * Equivalent to {@link #topLevelReduce(List, AggregationReduceContext)} but it takes a list of
-     * {@link DelayableWriteable}. The object will be expanded once via {@link DelayableWriteable#expand()}
-     * but it is the responsibility of the caller to release those releasables.
+     * Equivalent to {@link #topLevelReduce(List, AggregationReduceContext)} but it takes an iterator and a count.
      */
-    public static InternalAggregations topLevelReduceDelayable(
-        List<DelayableWriteable<InternalAggregations>> delayableAggregations,
-        AggregationReduceContext context
-    ) {
-        final List<InternalAggregations> aggregations = new AbstractList<>() {
-            @Override
-            public InternalAggregations get(int index) {
-                return delayableAggregations.get(index).expand();
-            }
-
-            @Override
-            public int size() {
-                return delayableAggregations.size();
-            }
-        };
-        return topLevelReduce(aggregations, context);
+    public static InternalAggregations topLevelReduce(Iterator<InternalAggregations> aggs, int count, AggregationReduceContext context) {
+        if (count == 0) {
+            return null;
+        }
+        return maybeExecuteFinalReduce(context, count == 1 ? reduce(aggs.next(), context) : reduce(aggs, count, context));
     }
 
-    /**
-     * Begin the reduction process.  This should be the entry point for the "first" reduction, e.g. called by
-     * SearchPhaseController or anywhere else that wants to initiate a reduction.  It _should not_ be called
-     * as an intermediate reduction step (e.g. in the middle of an aggregation tree).
-     *
-     * This method first reduces the aggregations, and if it is the final reduce, then reduce the pipeline
-     * aggregations (both embedded parent/sibling as well as top-level sibling pipelines)
-     */
-    public static InternalAggregations topLevelReduce(List<InternalAggregations> aggregationsList, AggregationReduceContext context) {
-        InternalAggregations reduced = reduce(aggregationsList, context);
+    private static InternalAggregations maybeExecuteFinalReduce(AggregationReduceContext context, InternalAggregations reduced) {
         if (reduced == null) {
             return null;
         }
         if (context.isFinalReduce()) {
-            List<InternalAggregation> reducedInternalAggs = reduced.getInternalAggregations();
-            reducedInternalAggs = reducedInternalAggs.stream()
+            List<InternalAggregation> reducedInternalAggs = reduced.getInternalAggregations()
+                .stream()
                 .map(agg -> agg.reducePipelines(agg, context, context.pipelineTreeRoot().subTree(agg.getName())))
                 .collect(Collectors.toCollection(ArrayList::new));
 
@@ -231,6 +207,18 @@ public final class InternalAggregations implements Iterable<InternalAggregation>
         return reduced;
     }
 
+    /**
+     * Begin the reduction process.  This should be the entry point for the "first" reduction, e.g. called by
+     * SearchPhaseController or anywhere else that wants to initiate a reduction.  It _should not_ be called
+     * as an intermediate reduction step (e.g. in the middle of an aggregation tree).
+     *
+     * This method first reduces the aggregations, and if it is the final reduce, then reduce the pipeline
+     * aggregations (both embedded parent/sibling as well as top-level sibling pipelines)
+     */
+    public static InternalAggregations topLevelReduce(List<InternalAggregations> aggregationsList, AggregationReduceContext context) {
+        return maybeExecuteFinalReduce(context, reduce(aggregationsList, context));
+    }
+
     /**
      * Reduces the given list of aggregations as well as the top-level pipeline aggregators extracted from the first
      * {@link InternalAggregations} object found in the list.
@@ -254,6 +242,16 @@ public final class InternalAggregations implements Iterable<InternalAggregation>
         }
     }
 
+    private static InternalAggregations reduce(Iterator<InternalAggregations> aggsIterator, int count, AggregationReduceContext context) {
+        // general case
+        var first = aggsIterator.next();
+        try (AggregatorsReducer reducer = new AggregatorsReducer(first, context, count)) {
+            reducer.accept(first);
+            aggsIterator.forEachRemaining(reducer::accept);
+            return reducer.get();
+        }
+    }
+
     public static InternalAggregations reduce(InternalAggregations aggregations, AggregationReduceContext context) {
         final List<InternalAggregation> internalAggregations = aggregations.asList();
         int size = internalAggregations.size();

+ 9 - 0
server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java

@@ -236,6 +236,15 @@ public final class QuerySearchResult extends SearchPhaseResult {
         return aggregations;
     }
 
+    public DelayableWriteable<InternalAggregations> consumeAggs() {
+        if (aggregations == null) {
+            throw new IllegalStateException("aggs already released");
+        }
+        var res = aggregations;
+        aggregations = null;
+        return res;
+    }
+
     /**
      * Release the memory hold by the {@link DelayableWriteable} aggregations
      * @throws IllegalStateException if {@link #releaseAggs()} has already being called.

+ 4 - 9
server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java

@@ -68,7 +68,6 @@ import org.elasticsearch.search.suggest.phrase.PhraseSuggestion;
 import org.elasticsearch.search.suggest.term.TermSuggestion;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.test.InternalAggregationTestCase;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportMessage;
@@ -273,14 +272,12 @@ public class SearchPhaseControllerTests extends ESTestCase {
             try {
                 SearchPhaseController.ReducedQueryPhase reducedQueryPhase = SearchPhaseController.reducedQueryPhase(
                     queryResults.asList(),
-                    new ArrayList<>(),
+                    InternalAggregations.EMPTY,
                     new ArrayList<>(),
                     new TopDocsStats(trackTotalHits),
                     0,
                     true,
-                    InternalAggregationTestCase.emptyReduceContextBuilder(),
-                    null,
-                    true
+                    null
                 );
                 List<SearchShardTarget> shards = queryResults.asList()
                     .stream()
@@ -363,12 +360,11 @@ public class SearchPhaseControllerTests extends ESTestCase {
             try {
                 SearchPhaseController.ReducedQueryPhase reducedQueryPhase = SearchPhaseController.reducedQueryPhase(
                     queryResults.asList(),
-                    new ArrayList<>(),
+                    InternalAggregations.EMPTY,
                     new ArrayList<>(),
                     new TopDocsStats(trackTotalHits),
                     0,
                     true,
-                    InternalAggregationTestCase.emptyReduceContextBuilder(),
                     new QueryPhaseRankCoordinatorContext(windowSize) {
                         @Override
                         public ScoreDoc[] rankQueryPhaseResults(List<QuerySearchResult> querySearchResults, TopDocsStats topDocStats) {
@@ -395,8 +391,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
                             topDocStats.fetchHits = topResults.length;
                             return topResults;
                         }
-                    },
-                    true
+                    }
                 );
                 List<SearchShardTarget> shards = queryResults.asList()
                     .stream()