Browse Source

Refactoring on merging InternalTerms (#107049)

This refactor introduces a TermsAggregationReducer that holds the logic to merge InternalTerms. The main difference
is that we now accumulate the buckets instead of the internal aggregations.
Ignacio Vera 1 year ago
parent
commit
43efc95057

+ 98 - 93
server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/AbstractInternalTerms.java

@@ -12,6 +12,7 @@ import org.apache.lucene.util.PriorityQueue;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.search.aggregations.AggregationErrors;
 import org.elasticsearch.search.aggregations.AggregationReduceContext;
+import org.elasticsearch.search.aggregations.AggregatorReducer;
 import org.elasticsearch.search.aggregations.BucketOrder;
 import org.elasticsearch.search.aggregations.DelayedBucket;
 import org.elasticsearch.search.aggregations.InternalAggregation;
@@ -112,23 +113,6 @@ public abstract class AbstractInternalTerms<A extends AbstractInternalTerms<A, B
         return createBucket(docCount, aggs, docCountError, buckets.get(0));
     }
 
-    private BucketOrder getReduceOrder(List<InternalAggregation> aggregations) {
-        BucketOrder thisReduceOrder = null;
-        for (InternalAggregation aggregation : aggregations) {
-            @SuppressWarnings("unchecked")
-            A terms = (A) aggregation;
-            if (terms.getBuckets().size() == 0) {
-                continue;
-            }
-            if (thisReduceOrder == null) {
-                thisReduceOrder = terms.getReduceOrder();
-            } else if (thisReduceOrder.equals(terms.getReduceOrder()) == false) {
-                return getOrder();
-            }
-        }
-        return thisReduceOrder != null ? thisReduceOrder : getOrder();
-    }
-
     private long getDocCountError(A terms) {
         int size = terms.getBuckets().size();
         if (size == 0 || size < terms.getShardSize() || isKeyOrder(terms.getOrder())) {
@@ -154,47 +138,37 @@ public abstract class AbstractInternalTerms<A extends AbstractInternalTerms<A, B
      * @return the order we used to reduce the buckets
      */
     private BucketOrder reduceBuckets(
-        List<InternalAggregation> aggregations,
+        List<List<B>> bucketsList,
+        BucketOrder thisReduceOrder,
         AggregationReduceContext reduceContext,
         Consumer<DelayedBucket<B>> sink
     ) {
-        /*
-         * Buckets returned by a partial reduce or a shard response are sorted by key since {@link Version#V_7_10_0}.
-         * That allows to perform a merge sort when reducing multiple aggregations together.
-         * For backward compatibility, we disable the merge sort and use ({@link #reduceLegacy} if any of
-         * the provided aggregations use a different {@link #reduceOrder}.
-         */
-        BucketOrder thisReduceOrder = getReduceOrder(aggregations);
         if (isKeyOrder(thisReduceOrder)) {
             // extract the primary sort in case this is a compound order.
             thisReduceOrder = InternalOrder.key(isKeyAsc(thisReduceOrder));
-            reduceMergeSort(aggregations, thisReduceOrder, reduceContext, sink);
+            reduceMergeSort(bucketsList, thisReduceOrder, reduceContext, sink);
         } else {
-            reduceLegacy(aggregations, reduceContext, sink);
+            reduceLegacy(bucketsList, reduceContext, sink);
         }
         return thisReduceOrder;
     }
 
     private void reduceMergeSort(
-        List<InternalAggregation> aggregations,
+        List<List<B>> bucketsList,
         BucketOrder thisReduceOrder,
         AggregationReduceContext reduceContext,
         Consumer<DelayedBucket<B>> sink
     ) {
         assert isKeyOrder(thisReduceOrder);
         final Comparator<Bucket> cmp = thisReduceOrder.comparator();
-        final PriorityQueue<IteratorAndCurrent<B>> pq = new PriorityQueue<>(aggregations.size()) {
+        final PriorityQueue<IteratorAndCurrent<B>> pq = new PriorityQueue<>(bucketsList.size()) {
             @Override
             protected boolean lessThan(IteratorAndCurrent<B> a, IteratorAndCurrent<B> b) {
                 return cmp.compare(a.current(), b.current()) < 0;
             }
         };
-        for (InternalAggregation aggregation : aggregations) {
-            @SuppressWarnings("unchecked")
-            A terms = (A) aggregation;
-            if (terms.getBuckets().isEmpty() == false) {
-                pq.add(new IteratorAndCurrent<>(terms.getBuckets().iterator()));
-            }
+        for (List<B> buckets : bucketsList) {
+            pq.add(new IteratorAndCurrent<>(buckets.iterator()));
         }
         // list of buckets coming from different shards that have the same key
         List<B> sameTermBuckets = new ArrayList<>();
@@ -228,19 +202,11 @@ public abstract class AbstractInternalTerms<A extends AbstractInternalTerms<A, B
         }
     }
 
-    private void reduceLegacy(
-        List<InternalAggregation> aggregations,
-        AggregationReduceContext reduceContext,
-        Consumer<DelayedBucket<B>> sink
-    ) {
-        Map<Object, List<B>> bucketMap = new HashMap<>();
-        for (InternalAggregation aggregation : aggregations) {
-            @SuppressWarnings("unchecked")
-            A terms = (A) aggregation;
-            if (terms.getBuckets().isEmpty() == false) {
-                for (B bucket : terms.getBuckets()) {
-                    bucketMap.computeIfAbsent(bucket.getKey(), k -> new ArrayList<>()).add(bucket);
-                }
+    private void reduceLegacy(List<List<B>> bucketsList, AggregationReduceContext reduceContext, Consumer<DelayedBucket<B>> sink) {
+        final Map<Object, List<B>> bucketMap = new HashMap<>();
+        for (List<B> buckets : bucketsList) {
+            for (B bucket : buckets) {
+                bucketMap.computeIfAbsent(bucket.getKey(), k -> new ArrayList<>()).add(bucket);
             }
         }
         for (List<B> sameTermBuckets : bucketMap.values()) {
@@ -248,21 +214,49 @@ public abstract class AbstractInternalTerms<A extends AbstractInternalTerms<A, B
         }
     }
 
-    public InternalAggregation doReduce(List<InternalAggregation> aggregations, AggregationReduceContext reduceContext) {
-        long sumDocCountError = 0;
-        long[] otherDocCount = new long[] { 0 };
-        A referenceTerms = null;
-        for (InternalAggregation aggregation : aggregations) {
+    public final AggregatorReducer termsAggregationReducer(AggregationReduceContext reduceContext, int size) {
+        return new TermsAggregationReducer(reduceContext, size);
+    }
+
+    private class TermsAggregationReducer implements AggregatorReducer {
+        private final List<List<B>> bucketsList;
+        private final AggregationReduceContext reduceContext;
+
+        private long sumDocCountError = 0;
+        private final long[] otherDocCount = new long[] { 0 };
+        private A referenceTerms = null;
+        /*
+         * Buckets returned by a partial reduce or a shard response are sorted by key since {@link Version#V_7_10_0}.
+         * That allows us to perform a merge sort when reducing multiple aggregations together.
+         * For backward compatibility, we disable the merge sort and use {@link #reduceLegacy} if any of
+         * the provided aggregations use a different {@link #reduceOrder}.
+         */
+        private BucketOrder thisReduceOrder = null;
+
+        private TermsAggregationReducer(AggregationReduceContext reduceContext, int size) {
+            bucketsList = new ArrayList<>(size);
+            this.reduceContext = reduceContext;
+        }
+
+        @Override
+        public void accept(InternalAggregation aggregation) {
+            if (aggregation.canLeadReduction() == false) {
+                return;
+            }
             @SuppressWarnings("unchecked")
             A terms = (A) aggregation;
-            if (referenceTerms == null && terms.canLeadReduction()) {
+            if (referenceTerms == null) {
                 referenceTerms = terms;
-            }
-            if (referenceTerms != null && referenceTerms.getClass().equals(terms.getClass()) == false && terms.canLeadReduction()) {
+            } else if (referenceTerms.getClass().equals(terms.getClass()) == false) {
                 // control gets into this loop when the same field name against which the query is executed
                 // is of different types in different indices.
                 throw AggregationErrors.reduceTypeMismatch(referenceTerms.getName(), Optional.empty());
             }
+            if (thisReduceOrder == null) {
+                thisReduceOrder = terms.getReduceOrder();
+            } else if (thisReduceOrder != getOrder() && thisReduceOrder.equals(terms.getReduceOrder()) == false) {
+                thisReduceOrder = getOrder();
+            }
             otherDocCount[0] += terms.getSumOfOtherDocCounts();
             final long thisAggDocCountError = getDocCountError(terms);
             if (sumDocCountError != -1) {
@@ -283,52 +277,63 @@ public abstract class AbstractInternalTerms<A extends AbstractInternalTerms<A, B
                 // later in this method.
                 bucket.updateDocCountError(-thisAggDocCountError);
             }
+            if (terms.getBuckets().isEmpty() == false) {
+                bucketsList.add(terms.getBuckets());
+            }
         }
 
-        BucketOrder thisReduceOrder;
-        List<B> result;
-        if (reduceContext.isFinalReduce()) {
-            TopBucketBuilder<B> top = TopBucketBuilder.build(
-                getRequiredSize(),
-                getOrder(),
-                removed -> otherDocCount[0] += removed.getDocCount()
-            );
-            thisReduceOrder = reduceBuckets(aggregations, reduceContext, bucket -> {
-                if (bucket.getDocCount() >= getMinDocCount()) {
-                    top.add(bucket);
-                }
-            });
-            result = top.build();
-        } else {
-            /*
-             * We can prune the list on partial reduce if the aggregation is ordered
-             * by key and not filtered on doc count. The results come in key order
-             * so we can just stop iteration early.
-             */
-            boolean canPrune = isKeyOrder(getOrder()) && getMinDocCount() == 0;
-            result = new ArrayList<>();
-            thisReduceOrder = reduceBuckets(aggregations, reduceContext, bucket -> {
-                if (canPrune == false || result.size() < getRequiredSize()) {
-                    result.add(bucket.reduced());
+        @Override
+        public InternalAggregation get() {
+            BucketOrder thisReduceOrder;
+            List<B> result;
+            if (isKeyOrder(getOrder()) && getMinDocCount() <= 1) {
+                /*
+                 * The aggregation is ordered by key and not filtered on doc count. The results come in key order
+                 * so we can just use an optimized collection.
+                 */
+                result = new ArrayList<>();
+                thisReduceOrder = reduceBuckets(bucketsList, getThisReduceOrder(), reduceContext, bucket -> {
+                    if (result.size() < getRequiredSize()) {
+                        result.add(bucket.reduced());
+                    } else {
+                        otherDocCount[0] += bucket.getDocCount();
+                    }
+                });
+            } else if (reduceContext.isFinalReduce()) {
+                TopBucketBuilder<B> top = TopBucketBuilder.build(
+                    getRequiredSize(),
+                    getOrder(),
+                    removed -> otherDocCount[0] += removed.getDocCount()
+                );
+                thisReduceOrder = reduceBuckets(bucketsList, getThisReduceOrder(), reduceContext, bucket -> {
+                    if (bucket.getDocCount() >= getMinDocCount()) {
+                        top.add(bucket);
+                    }
+                });
+                result = top.build();
+            } else {
+                result = new ArrayList<>();
+                thisReduceOrder = reduceBuckets(bucketsList, getThisReduceOrder(), reduceContext, bucket -> result.add(bucket.reduced()));
+            }
+            for (B r : result) {
+                if (sumDocCountError == -1) {
+                    r.setDocCountError(-1);
                 } else {
-                    otherDocCount[0] += bucket.getDocCount();
+                    r.updateDocCountError(sumDocCountError);
                 }
-            });
-        }
-        for (B r : result) {
+            }
+            long docCountError;
             if (sumDocCountError == -1) {
-                r.setDocCountError(-1);
+                docCountError = -1;
             } else {
-                r.updateDocCountError(sumDocCountError);
+                docCountError = bucketsList.size() == 1 ? 0 : sumDocCountError;
             }
+            return create(name, result, reduceContext.isFinalReduce() ? getOrder() : thisReduceOrder, docCountError, otherDocCount[0]);
         }
-        long docCountError;
-        if (sumDocCountError == -1) {
-            docCountError = -1;
-        } else {
-            docCountError = aggregations.size() == 1 ? 0 : sumDocCountError;
+
+        private BucketOrder getThisReduceOrder() {
+            return thisReduceOrder == null ? getOrder() : thisReduceOrder;
         }
-        return create(name, result, reduceContext.isFinalReduce() ? getOrder() : thisReduceOrder, docCountError, otherDocCount[0]);
     }
 
     @Override

+ 10 - 6
server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/DoubleTerms.java

@@ -9,6 +9,7 @@ package org.elasticsearch.search.aggregations.bucket.terms;
 
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.aggregations.AggregationReduceContext;
 import org.elasticsearch.search.aggregations.AggregatorReducer;
@@ -18,7 +19,6 @@ import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -190,21 +190,25 @@ public class DoubleTerms extends InternalMappedTerms<DoubleTerms, DoubleTerms.Bu
     @Override
     protected AggregatorReducer getLeaderReducer(AggregationReduceContext reduceContext, int size) {
         return new AggregatorReducer() {
-            private final List<InternalAggregation> aggregations = new ArrayList<>();
+            private final AggregatorReducer processor = termsAggregationReducer(reduceContext, size);
 
             @Override
             public void accept(InternalAggregation aggregation) {
                 if (aggregation instanceof LongTerms longTerms) {
-                    DoubleTerms dTerms = LongTerms.convertLongTermsToDouble(longTerms, format);
-                    aggregations.add(dTerms);
+                    processor.accept(LongTerms.convertLongTermsToDouble(longTerms, format));
                 } else {
-                    aggregations.add(aggregation);
+                    processor.accept(aggregation);
                 }
             }
 
             @Override
             public InternalAggregation get() {
-                return ((AbstractInternalTerms<?, ?>) aggregations.get(0)).doReduce(aggregations, reduceContext);
+                return processor.get();
+            }
+
+            @Override
+            public void close() {
+                Releasables.close(processor);
             }
         };
     }

+ 12 - 3
server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/LongTerms.java

@@ -213,8 +213,8 @@ public class LongTerms extends InternalMappedTerms<LongTerms, LongTerms.Bucket>
         }
         return new AggregatorReducer() {
 
-            final List<InternalAggregation> aggregations = new ArrayList<>(size);
-            boolean isPromotedToDouble = false;
+            private List<InternalAggregation> aggregations = new ArrayList<>(size);
+            private boolean isPromotedToDouble = false;
 
             @Override
             public void accept(InternalAggregation aggregation) {
@@ -243,7 +243,16 @@ public class LongTerms extends InternalMappedTerms<LongTerms, LongTerms.Bucket>
 
             @Override
             public InternalAggregation get() {
-                return ((AbstractInternalTerms<?, ?>) aggregations.get(0)).doReduce(aggregations, reduceContext);
+                try (
+                    AggregatorReducer processor = ((AbstractInternalTerms<?, ?>) aggregations.get(0)).termsAggregationReducer(
+                        reduceContext,
+                        size
+                    )
+                ) {
+                    aggregations.forEach(processor::accept);
+                    aggregations = null; // release memory
+                    return processor.get();
+                }
             }
         };
     }

+ 1 - 15
server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/StringTerms.java

@@ -14,12 +14,10 @@ import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.aggregations.AggregationReduceContext;
 import org.elasticsearch.search.aggregations.AggregatorReducer;
 import org.elasticsearch.search.aggregations.BucketOrder;
-import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -153,19 +151,7 @@ public class StringTerms extends InternalMappedTerms<StringTerms, StringTerms.Bu
 
     @Override
     protected AggregatorReducer getLeaderReducer(AggregationReduceContext reduceContext, int size) {
-        return new AggregatorReducer() {
-            private final List<InternalAggregation> aggregations = new ArrayList<>(size);
-
-            @Override
-            public void accept(InternalAggregation aggregation) {
-                aggregations.add(aggregation);
-            }
-
-            @Override
-            public InternalAggregation get() {
-                return ((AbstractInternalTerms<?, ?>) aggregations.get(0)).doReduce(aggregations, reduceContext);
-            }
-        };
+        return termsAggregationReducer(reduceContext, size);
     }
 
     @Override

+ 60 - 53
x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/multiterms/InternalMultiTerms.java

@@ -439,44 +439,6 @@ public class InternalMultiTerms extends AbstractInternalTerms<InternalMultiTerms
         );
     }
 
-    /**
-     * Checks if any keys need to be promoted to double from long or unsigned_long
-     */
-    private boolean[] needsPromotionToDouble(List<InternalAggregation> aggregations) {
-        if (aggregations.size() < 2) {
-            return null;
-        }
-        boolean[] promotions = null;
-
-        for (int i = 0; i < keyConverters.size(); i++) {
-            boolean hasLong = false;
-            boolean hasUnsignedLong = false;
-            boolean hasDouble = false;
-            boolean hasNonNumber = false;
-            for (InternalAggregation aggregation : aggregations) {
-                InternalMultiTerms agg = (InternalMultiTerms) aggregation;
-                KeyConverter keyConverter = agg.keyConverters.get(i);
-                switch (keyConverter) {
-                    case DOUBLE -> hasDouble = true;
-                    case LONG -> hasLong = true;
-                    case UNSIGNED_LONG -> hasUnsignedLong = true;
-                    default -> hasNonNumber = true;
-                }
-            }
-            if (hasNonNumber && (hasDouble || hasUnsignedLong || hasLong)) {
-                throw AggregationErrors.reduceTypeMismatch(name, Optional.of(i + 1));
-            }
-            // Promotion to double is required if at least 2 of these 3 conditions are true.
-            if ((hasDouble ? 1 : 0) + (hasUnsignedLong ? 1 : 0) + (hasLong ? 1 : 0) > 1) {
-                if (promotions == null) {
-                    promotions = new boolean[keyConverters.size()];
-                }
-                promotions[i] = true;
-            }
-        }
-        return promotions;
-    }
-
     private InternalAggregation promoteToDouble(InternalAggregation aggregation, boolean[] needsPromotion) {
         InternalMultiTerms multiTerms = (InternalMultiTerms) aggregation;
         List<Bucket> multiTermsBuckets = multiTerms.getBuckets();
@@ -539,33 +501,78 @@ public class InternalMultiTerms extends AbstractInternalTerms<InternalMultiTerms
         );
     }
 
-    public List<InternalAggregation> getProcessedAggs(List<InternalAggregation> aggregations, boolean[] needsPromotionToDouble) {
-        if (needsPromotionToDouble != null) {
-            List<InternalAggregation> newAggs = new ArrayList<>(aggregations.size());
-            for (InternalAggregation agg : aggregations) {
-                newAggs.add(promoteToDouble(agg, needsPromotionToDouble));
-            }
-            return newAggs;
-        } else {
-            return aggregations;
-        }
-    }
-
     @Override
     protected AggregatorReducer getLeaderReducer(AggregationReduceContext reduceContext, int size) {
         return new AggregatorReducer() {
 
-            final List<InternalAggregation> aggregations = new ArrayList<>(size);
+            private List<InternalAggregation> aggregations = new ArrayList<>(size);
 
             @Override
             public void accept(InternalAggregation aggregation) {
                 aggregations.add(aggregation);
             }
 
+            private List<InternalAggregation> getProcessedAggs(List<InternalAggregation> aggregations, boolean[] needsPromotionToDouble) {
+                if (needsPromotionToDouble != null) {
+                    aggregations.replaceAll(agg -> promoteToDouble(agg, needsPromotionToDouble));
+                }
+                return aggregations;
+            }
+
+            /**
+             * Checks if any keys need to be promoted to double from long or unsigned_long
+             */
+            private boolean[] needsPromotionToDouble(List<InternalAggregation> aggregations) {
+                if (aggregations.size() < 2) {
+                    return null;
+                }
+                boolean[] promotions = null;
+
+                for (int i = 0; i < keyConverters.size(); i++) {
+                    boolean hasLong = false;
+                    boolean hasUnsignedLong = false;
+                    boolean hasDouble = false;
+                    boolean hasNonNumber = false;
+                    for (InternalAggregation aggregation : aggregations) {
+                        InternalMultiTerms agg = (InternalMultiTerms) aggregation;
+                        KeyConverter keyConverter = agg.keyConverters.get(i);
+                        switch (keyConverter) {
+                            case DOUBLE -> hasDouble = true;
+                            case LONG -> hasLong = true;
+                            case UNSIGNED_LONG -> hasUnsignedLong = true;
+                            default -> hasNonNumber = true;
+                        }
+                    }
+                    if (hasNonNumber && (hasDouble || hasUnsignedLong || hasLong)) {
+                        throw AggregationErrors.reduceTypeMismatch(name, Optional.of(i + 1));
+                    }
+                    // Promotion to double is required if at least 2 of these 3 conditions are true.
+                    if ((hasDouble ? 1 : 0) + (hasUnsignedLong ? 1 : 0) + (hasLong ? 1 : 0) > 1) {
+                        if (promotions == null) {
+                            promotions = new boolean[keyConverters.size()];
+                        }
+                        promotions[i] = true;
+                    }
+                }
+                return promotions;
+            }
+
             @Override
             public InternalAggregation get() {
-                List<InternalAggregation> processed = getProcessedAggs(aggregations, needsPromotionToDouble(aggregations));
-                return ((AbstractInternalTerms<?, ?>) processed.get(0)).doReduce(processed, reduceContext);
+                final boolean[] needsPromotionToDouble = needsPromotionToDouble(aggregations);
+                if (needsPromotionToDouble != null) {
+                    aggregations.replaceAll(agg -> promoteToDouble(agg, needsPromotionToDouble));
+                }
+                try (
+                    AggregatorReducer processor = ((AbstractInternalTerms<?, ?>) aggregations.get(0)).termsAggregationReducer(
+                        reduceContext,
+                        size
+                    )
+                ) {
+                    aggregations.forEach(processor::accept);
+                    aggregations = null; // release memory
+                    return processor.get();
+                }
             }
         };
     }