
Drier and faster SumAggregator and AvgAggregator (#120436) (#120610)

Dried up the summation logic for both implementations and moved it to the much faster inline variant.
This could obviously have been made even drier, but that didn't seem possible without a performance hit
(we really don't want to sub-class the leaf collector, I think).
Benchmarks suggest this variant is ~10% faster than the previous iteration of `SumAggregator` (probably from
making the grow method smaller), with an even bigger improvement for the `AvgAggregator`.
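For context, the shared helpers (`sumSortedDoubles`/`computeSum`) apply per-bucket Kahan (compensated) summation. Below is a minimal standalone sketch of that pattern; plain `double[]` arrays stand in for `BigArrays`, and `KahanSketch` and its method names are illustrative only, not part of this commit:

```java
// Minimal sketch of per-bucket Kahan (compensated) summation, mirroring the
// shared summation logic in SumAggregator. Plain arrays stand in for BigArrays;
// the KahanSketch class and its method names are illustrative only.
class KahanSketch {
    private final double[] sums;
    private final double[] compensations;

    KahanSketch(int buckets) {
        sums = new double[buckets];
        compensations = new double[buckets];
    }

    /** Adds {@code values} to {@code bucket}, returning the number of values added. */
    int add(int bucket, double... values) {
        double value = sums[bucket];
        double delta = compensations[bucket];
        for (double added : values) {
            if (Double.isFinite(added) == false) {
                // Inf/NaN bypasses compensation: just fold it into the running tally.
                value = added + value;
            }
            if (Double.isFinite(value)) {
                // Kahan step: delta carries the low-order bits lost by the addition.
                double correctedSum = added + delta;
                double updatedValue = value + correctedSum;
                delta = correctedSum - (updatedValue - value);
                value = updatedValue;
            }
        }
        sums[bucket] = value;
        compensations[bucket] = delta;
        return values.length;
    }

    double sum(int bucket) {
        return sums[bucket];
    }
}
```

For example, `new KahanSketch(1).add(0, 1e16, 1.0, 1.0)` leaves the bucket at `10000000000000002.0`, whereas naive left-to-right summation yields `10000000000000000.0` because each `1.0` is lost to rounding against `1e16`.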
Armin Braun · 9 months ago · commit 7cd58fa401

+ 10 - 37
server/src/main/java/org/elasticsearch/search/aggregations/metrics/AvgAggregator.java

@@ -9,12 +9,10 @@
 package org.elasticsearch.search.aggregations.metrics;
 
 import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.common.util.DoubleArray;
 import org.elasticsearch.common.util.LongArray;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.index.fielddata.NumericDoubleValues;
 import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
-import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.aggregations.Aggregator;
 import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.LeafBucketCollector;
@@ -25,12 +23,9 @@ import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
 import java.io.IOException;
 import java.util.Map;
 
-class AvgAggregator extends NumericMetricsAggregator.SingleDoubleValue {
+class AvgAggregator extends SumAggregator {
 
     LongArray counts;
-    DoubleArray sums;
-    DoubleArray compensations;
-    DocValueFormat format;
 
     AvgAggregator(
         String name,
@@ -40,32 +35,17 @@ class AvgAggregator extends NumericMetricsAggregator.SingleDoubleValue {
         Map<String, Object> metadata
     ) throws IOException {
         super(name, valuesSourceConfig, context, parent, metadata);
-        assert valuesSourceConfig.hasValues();
-        this.format = valuesSourceConfig.format();
-        final BigArrays bigArrays = context.bigArrays();
-        counts = bigArrays.newLongArray(1, true);
-        sums = bigArrays.newDoubleArray(1, true);
-        compensations = bigArrays.newDoubleArray(1, true);
+        counts = context.bigArrays().newLongArray(1, true);
     }
 
     @Override
     protected LeafBucketCollector getLeafCollector(SortedNumericDoubleValues values, final LeafBucketCollector sub) {
-        final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
         return new LeafBucketCollectorBase(sub, values) {
             @Override
             public void collect(int doc, long bucket) throws IOException {
                 if (values.advanceExact(doc)) {
                     maybeGrow(bucket);
-                    final int valueCount = values.docValueCount();
-                    counts.increment(bucket, valueCount);
-                    // Compute the sum of double values with Kahan summation algorithm which is more
-                    // accurate than naive summation.
-                    kahanSummation.reset(sums.get(bucket), compensations.get(bucket));
-                    for (int i = 0; i < valueCount; i++) {
-                        kahanSummation.add(values.nextValue());
-                    }
-                    sums.set(bucket, kahanSummation.value());
-                    compensations.set(bucket, kahanSummation.delta());
+                    counts.increment(bucket, sumSortedDoubles(bucket, values, sums, compensations));
                 }
             }
         };
@@ -73,30 +53,22 @@ class AvgAggregator extends NumericMetricsAggregator.SingleDoubleValue {
 
     @Override
     protected LeafBucketCollector getLeafCollector(NumericDoubleValues values, final LeafBucketCollector sub) {
-        final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
         return new LeafBucketCollectorBase(sub, values) {
             @Override
             public void collect(int doc, long bucket) throws IOException {
                 if (values.advanceExact(doc)) {
                     maybeGrow(bucket);
+                    computeSum(bucket, values, sums, compensations);
                     counts.increment(bucket, 1L);
-                    // Compute the sum of double values with Kahan summation algorithm which is more
-                    // accurate than naive summation.
-                    kahanSummation.reset(sums.get(bucket), compensations.get(bucket));
-                    kahanSummation.add(values.doubleValue());
-                    sums.set(bucket, kahanSummation.value());
-                    compensations.set(bucket, kahanSummation.delta());
                 }
             }
         };
     }
 
-    private void maybeGrow(long bucket) {
-        if (bucket >= counts.size()) {
-            counts = bigArrays().grow(counts, bucket + 1);
-            sums = bigArrays().grow(sums, bucket + 1);
-            compensations = bigArrays().grow(compensations, bucket + 1);
-        }
+    @Override
+    protected void doGrow(long bucket, BigArrays bigArrays) {
+        super.doGrow(bucket, bigArrays);
+        counts = bigArrays.grow(counts, bucket + 1);
     }
 
     @Override
@@ -122,7 +94,8 @@ class AvgAggregator extends NumericMetricsAggregator.SingleDoubleValue {
 
     @Override
     public void doClose() {
-        Releasables.close(counts, sums, compensations);
+        super.doClose();
+        Releasables.close(counts);
     }
 
 }

+ 66 - 40
server/src/main/java/org/elasticsearch/search/aggregations/metrics/SumAggregator.java

@@ -8,6 +8,7 @@
  */
 package org.elasticsearch.search.aggregations.metrics;
 
+import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.DoubleArray;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.index.fielddata.NumericDoubleValues;
@@ -25,10 +26,9 @@ import java.util.Map;
 
 public class SumAggregator extends NumericMetricsAggregator.SingleDoubleValue {
 
-    private final DocValueFormat format;
-
-    private DoubleArray sums;
-    private DoubleArray compensations;
+    protected final DocValueFormat format;
+    protected DoubleArray sums;
+    protected DoubleArray compensations;
 
     SumAggregator(
         String name,
@@ -40,31 +40,56 @@ public class SumAggregator extends NumericMetricsAggregator.SingleDoubleValue {
         super(name, valuesSourceConfig, context, parent, metadata);
         assert valuesSourceConfig.hasValues();
         this.format = valuesSourceConfig.format();
-        sums = bigArrays().newDoubleArray(1, true);
-        compensations = bigArrays().newDoubleArray(1, true);
+        var bigArrays = context.bigArrays();
+        sums = bigArrays.newDoubleArray(1, true);
+        compensations = bigArrays.newDoubleArray(1, true);
     }
 
     @Override
     protected LeafBucketCollector getLeafCollector(SortedNumericDoubleValues values, final LeafBucketCollector sub) {
-        final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
         return new LeafBucketCollectorBase(sub, values) {
             @Override
             public void collect(int doc, long bucket) throws IOException {
                 if (values.advanceExact(doc)) {
                     maybeGrow(bucket);
-                    // Compute the sum of double values with Kahan summation algorithm which is more
-                    // accurate than naive summation.
-                    kahanSummation.reset(sums.get(bucket), compensations.get(bucket));
-                    for (int i = 0; i < values.docValueCount(); i++) {
-                        kahanSummation.add(values.nextValue());
-                    }
-                    compensations.set(bucket, kahanSummation.delta());
-                    sums.set(bucket, kahanSummation.value());
+                    sumSortedDoubles(bucket, values, sums, compensations);
                 }
             }
         };
     }
 
+    // returns number of values added
+    static int sumSortedDoubles(long bucket, SortedNumericDoubleValues values, DoubleArray sums, DoubleArray compensations)
+        throws IOException {
+        final int valueCount = values.docValueCount();
+        // Compute the sum of double values with Kahan summation algorithm which is more
+        // accurate than naive summation.
+        double value = sums.get(bucket);
+        double delta = compensations.get(bucket);
+        for (int i = 0; i < valueCount; i++) {
+            double added = values.nextValue();
+            value = addIfNonOrInf(added, value);
+            if (Double.isFinite(value)) {
+                double correctedSum = added + delta;
+                double updatedValue = value + correctedSum;
+                delta = correctedSum - (updatedValue - value);
+                value = updatedValue;
+            }
+        }
+        compensations.set(bucket, delta);
+        sums.set(bucket, value);
+        return valueCount;
+    }
+
+    private static double addIfNonOrInf(double added, double value) {
+        // If the value is Inf or NaN, just add it to the running tally to "convert" to
+        // Inf/NaN. This keeps the behavior bwc from before kahan summing
+        if (Double.isFinite(added)) {
+            return value;
+        }
+        return added + value;
+    }
+
     @Override
     protected LeafBucketCollector getLeafCollector(NumericDoubleValues values, final LeafBucketCollector sub) {
         return new LeafBucketCollectorBase(sub, values) {
@@ -72,40 +97,41 @@ public class SumAggregator extends NumericMetricsAggregator.SingleDoubleValue {
             public void collect(int doc, long bucket) throws IOException {
                 if (values.advanceExact(doc)) {
                     maybeGrow(bucket);
-                    var sums = SumAggregator.this.sums;
-                    // Compute the sum of double values with Kahan summation algorithm which is more
-                    // accurate than naive summation.
-                    double value = sums.get(bucket);
-                    // If the value is Inf or NaN, just add it to the running tally to "convert" to
-                    // Inf/NaN. This keeps the behavior bwc from before kahan summing
-                    double v = values.doubleValue();
-                    if (Double.isFinite(v) == false) {
-                        value = v + value;
-                    }
-
-                    if (Double.isFinite(value)) {
-                        var compensations = SumAggregator.this.compensations;
-                        double delta = compensations.get(bucket);
-                        double correctedSum = v + delta;
-                        double updatedValue = value + correctedSum;
-                        delta = correctedSum - (updatedValue - value);
-                        value = updatedValue;
-                        compensations.set(bucket, delta);
-                    }
-
-                    sums.set(bucket, value);
+                    computeSum(bucket, values, sums, compensations);
                 }
             }
         };
     }
 
-    private void maybeGrow(long bucket) {
+    static void computeSum(long bucket, NumericDoubleValues values, DoubleArray sums, DoubleArray compensations) throws IOException {
+        // Compute the sum of double values with Kahan summation algorithm which is more
+        // accurate than naive summation.
+        double added = values.doubleValue();
+        double value = addIfNonOrInf(added, sums.get(bucket));
+        if (Double.isFinite(value)) {
+            double delta = compensations.get(bucket);
+            double correctedSum = added + delta;
+            double updatedValue = value + correctedSum;
+            delta = correctedSum - (updatedValue - value);
+            value = updatedValue;
+            compensations.set(bucket, delta);
+        }
+
+        sums.set(bucket, value);
+    }
+
+    protected final void maybeGrow(long bucket) {
         if (bucket >= sums.size()) {
-            sums = bigArrays().grow(sums, bucket + 1);
-            compensations = bigArrays().grow(compensations, bucket + 1);
+            var bigArrays = bigArrays();
+            doGrow(bucket, bigArrays);
         }
     }
 
+    protected void doGrow(long bucket, BigArrays bigArrays) {
+        sums = bigArrays.grow(sums, bucket + 1);
+        compensations = bigArrays.grow(compensations, bucket + 1);
+    }
+
     @Override
     public double metric(long owningBucketOrd) {
         if (owningBucketOrd >= sums.size()) {