
Improve vwh's distant bucket handling (#59094)

This modifies the `variable_width_histogram`'s distant bucket handling
to:
1. Properly handle integer overflows
2. Recalculate the average distance when new buckets are added on the
   ends. This should slow the rate at which we create extra buckets as
   we add more of them (see the sketches below).
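
The overflow fix in point 1 comes down to how the average bucket distance is computed. Below is a minimal, self-contained sketch, not the aggregator itself: the class name and centroid values are made up (scaled by Long.MAX_VALUE like the new test case further down), and only the two ways of averaging mirror the diff that follows.

    /**
     * Minimal sketch contrasting the old int-based average bucket distance
     * with the new double-based one. Centroid values are hypothetical.
     */
    public class AvgBucketDistanceSketch {
        public static void main(String[] args) {
            double[] centroids = { -1d * Long.MAX_VALUE, 19d * Long.MAX_VALUE, 80d * Long.MAX_VALUE };

            // Old approach: accumulating double gaps into an int. Java's narrowing
            // conversion clamps an out-of-range double at Integer.MAX_VALUE, so the
            // average comes out absurdly small for very large keys.
            int intSum = 0;
            for (int i = 0; i < centroids.length - 1; i++) {
                intSum += centroids[i + 1] - centroids[i];
            }
            System.out.println(intSum / (centroids.length - 1)); // 1073741823

            // New approach: centroids are kept sorted, so the average gap is just
            // (last - first) / (numClusters - 1), computed entirely in double.
            double avg = (centroids[centroids.length - 1] - centroids[0]) / (centroids.length - 1);
            System.out.println(avg); // ~3.7e20
        }
    }

With the old loop, nearly every document looks "distant" once keys get large, because the clamped average is tiny; the new field is a double and the sum is replaced by the endpoint difference.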
Nik Everett 5 years ago
parent commit 28ca127199

+ 18 - 14
server/src/main/java/org/elasticsearch/search/aggregations/bucket/histogram/VariableWidthHistogramAggregator.java

@@ -166,7 +166,7 @@ public class VariableWidthHistogramAggregator extends DeferableBucketAggregator
         public DoubleArray clusterSizes; // clusterSizes != bucketDocCounts when clusters are in the middle of a merge
         public int numClusters;
 
-        private int avgBucketDistance;
+        private double avgBucketDistance;
 
         MergeBucketsPhase(DoubleArray buffer, int bufferSize) {
             // Cluster the documents to reduce the number of buckets
@@ -174,15 +174,7 @@ public class VariableWidthHistogramAggregator extends DeferableBucketAggregator
             bucketBufferedDocs(buffer, bufferSize, shardSize * 3 / 4);
 
             if(bufferSize > 1) {
-                // Calculate the average distance between buckets
-                // Subsequent documents will be compared with this value to determine if they should be collected into
-                // an existing bucket or into a new bucket
-                // This can be done in a single linear scan because buckets are sorted by centroid
-                int sum = 0;
-                for (int i = 0; i < numClusters - 1; i++) {
-                    sum += clusterCentroids.get(i + 1) - clusterCentroids.get(i);
-                }
-                avgBucketDistance = (sum / (numClusters - 1));
+                updateAvgBucketDistance();
             }
         }
 
@@ -194,11 +186,9 @@ public class VariableWidthHistogramAggregator extends DeferableBucketAggregator
 
             final DoubleArray values;
             final long[] indexes;
-            int length;
 
             ClusterSorter(DoubleArray values, int length){
                 this.values = values;
-                this.length = length;
 
                 this.indexes = new long[length];
                 for(int i = 0; i < indexes.length; i++){
@@ -284,7 +274,7 @@ public class VariableWidthHistogramAggregator extends DeferableBucketAggregator
         @Override
         public CollectionPhase collectValue(LeafBucketCollector sub, int doc, double val) throws IOException{
             int bucketOrd = getNearestBucket(val);
-            double distance = Math.abs(clusterCentroids.get(bucketOrd)- val);
+            double distance = Math.abs(clusterCentroids.get(bucketOrd) - val);
             if(bucketOrd == -1 || distance > (2 * avgBucketDistance) && numClusters < shardSize) {
                 // Make a new bucket since the document is distant from all existing buckets
                 // TODO: (maybe) Create a new bucket for <b>all</b> distant docs and merge down to shardSize buckets at end
@@ -293,17 +283,31 @@ public class VariableWidthHistogramAggregator extends DeferableBucketAggregator
                 collectBucket(sub, doc, numClusters - 1);
 
                 if(val > clusterCentroids.get(bucketOrd)){
-                    // Insert just ahead of bucketOrd so that the array remains sorted
+                    /*
+                     * If the new value is bigger than the nearest bucket then insert
+                     * just ahead of bucketOrd so that the array remains sorted.
+                     */
                     bucketOrd += 1;
                 }
                 moveLastCluster(bucketOrd);
+                // We've added a new bucket so update the average distance between the buckets
+                updateAvgBucketDistance();
             } else {
                 addToCluster(bucketOrd, val);
                 collectExistingBucket(sub, doc, bucketOrd);
+                if (bucketOrd == 0 || bucketOrd == numClusters - 1) {
+                    // Only update average distance if the centroid of one of the end buckets is modified.
+                    updateAvgBucketDistance();
+                }
             }
             return this;
         }
 
+        private void updateAvgBucketDistance() {
+            // Centroids are sorted so the average distance is the difference between the first and last.
+            avgBucketDistance = (clusterCentroids.get(numClusters - 1) - clusterCentroids.get(0)) / (numClusters - 1);
+        }
+
         /**
          * Creates a new cluster with  <code>value</code> and appends it to the cluster arrays
          */
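
As for point 2 of the commit message, the effect of recalculating the average when an end bucket is added (or its centroid moves) can be seen in a small hypothetical example; the class name and values are invented, and only avgGap mirrors updateAvgBucketDistance above. Each distant value that opens a new end bucket widens the average gap, so the 2x-average threshold for opening the next bucket grows.

    /**
     * Hypothetical illustration: appending a distant end bucket widens the
     * average gap, so the next value must be even farther from every centroid
     * to open another bucket.
     */
    public class RecalcEffectSketch {
        public static void main(String[] args) {
            double[] centroids = { 0, 10, 20 };
            System.out.println(2 * avgGap(centroids));      // 20.0 -> threshold for a new bucket

            double[] afterOutlier = { 0, 10, 20, 100 };     // 100 opened a new end bucket
            System.out.println(2 * avgGap(afterOutlier));   // ~66.7 -> new buckets get harder to open
        }

        // Centroids are kept sorted, so the average gap is (last - first) / (n - 1),
        // matching updateAvgBucketDistance() in the diff above.
        static double avgGap(double[] centroids) {
            return (centroids[centroids.length - 1] - centroids[0]) / (centroids.length - 1);
        }
    }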

+ 35 - 14
server/src/test/java/org/elasticsearch/search/aggregations/bucket/histogram/VariableWidthHistogramAggregatorTests.java

@@ -39,7 +39,6 @@ import org.elasticsearch.index.mapper.NumberFieldMapper;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.AggregatorTestCase;
-import org.elasticsearch.search.aggregations.bucket.terms.InternalTerms;
 import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
 import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.InternalStats;
@@ -218,25 +217,25 @@ public class VariableWidthHistogramAggregatorTests extends AggregatorTestCase {
     // Once the cache limit is reached, cached documents are collected into (3/4 * shard_size) buckets
     // A new bucket should be added when there is a document that is distant from all existing buckets
     public void testNewBucketCreation() throws Exception {
-        final List<Number> dataset = Arrays.asList(-1, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 40, 30, 25, 32, 38, 80, 50, 75);
+        final List<Number> dataset = Arrays.asList(-1, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 40, 30, 25, 32, 36, 80, 50, 75, 60);
         double doubleError = 1d / 10000d;
 
         // Search (no reduce)
 
         // Expected clusters: [ (-1), (1), (3), (5), (7), (9), (11), (13), (15), (17),
-        //                      (19), (25, 30, 32), (38, 40), (50), (75, 80) ]
-        // Corresponding keys (centroids): [ -1, 1, 3, ..., 17, 19, 29, 39, 50, 77.5]
+        //                      (19), (25, 30, 32), (36, 40, 50), (60), (75, 80) ]
+        // Corresponding keys (centroids): [ -1, 1, 3, ..., 17, 19, 29, 42, 60, 77.5]
         // Note: New buckets are created for 30, 50, and 80 because they are distant from the other buckets
-        final List<Double> keys = Arrays.asList(-1d, 1d, 3d, 5d, 7d, 9d, 11d, 13d, 15d, 17d, 19d, 29d, 39d, 50d, 77.5d);
-        final List<Double> mins = Arrays.asList(-1d, 1d, 3d, 5d, 7d, 9d, 11d, 13d, 15d, 17d, 19d, 25d, 38d, 50d, 75d);
-        final List<Double> maxes = Arrays.asList(-1d, 1d, 3d, 5d, 7d, 9d, 11d, 13d, 15d, 17d, 19d, 32d, 40d, 50d, 80d);
-        final List<Integer> docCounts = Arrays.asList(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 2, 1, 2);
+        final List<Double> keys = Arrays.asList(-1d, 1d, 3d, 5d, 7d, 9d, 11d, 13d, 15d, 17d, 19d, 29d, 42d, 60d, 77.5d);
+        final List<Double> mins = Arrays.asList(-1d, 1d, 3d, 5d, 7d, 9d, 11d, 13d, 15d, 17d, 19d, 25d, 36d, 60d, 75d);
+        final List<Double> maxes = Arrays.asList(-1d, 1d, 3d, 5d, 7d, 9d, 11d, 13d, 15d, 17d, 19d, 32d, 50d, 60d, 80d);
+        final List<Integer> docCounts = Arrays.asList(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 1, 2);
         assert keys.size() == docCounts.size() && keys.size() == keys.size();
 
         final Map<Double, Integer> expectedDocCountOnlySearch = new HashMap<>();
         final Map<Double, Double> expectedMinsOnlySearch = new HashMap<>();
         final Map<Double, Double> expectedMaxesOnlySearch = new HashMap<>();
-        for(int i=0; i<keys.size(); i++){
+        for(int i=0; i < keys.size(); i++){
             expectedDocCountOnlySearch.put(keys.get(i), docCounts.get(i));
             expectedMinsOnlySearch.put(keys.get(i), mins.get(i));
             expectedMaxesOnlySearch.put(keys.get(i), maxes.get(i));
@@ -251,6 +250,31 @@ public class VariableWidthHistogramAggregatorTests extends AggregatorTestCase {
                     long expectedDocCount = expectedDocCountOnlySearch.getOrDefault(bucket.getKey(), 0).longValue();
                     double expectedCentroid = expectedMinsOnlySearch.getOrDefault(bucket.getKey(), 0d).doubleValue();
                     double expectedMax = expectedMaxesOnlySearch.getOrDefault(bucket.getKey(), 0d).doubleValue();
+                    assertEquals(bucket.getKeyAsString(), expectedDocCount, bucket.getDocCount());
+                    assertEquals(bucket.getKeyAsString(), expectedCentroid, bucket.min(), doubleError);
+                    assertEquals(bucket.getKeyAsString(), expectedMax, bucket.max(), doubleError);
+                });
+            });
+
+        // Rerun the test with very large keys which can cause an overflow
+        final Map<Double, Integer> expectedDocCountBigKeys = new HashMap<>();
+        final Map<Double, Double> expectedMinsBigKeys = new HashMap<>();
+        final Map<Double, Double> expectedMaxesBigKeys = new HashMap<>();
+        for(int i=0; i< keys.size(); i++){
+            expectedDocCountBigKeys.put(Long.MAX_VALUE * keys.get(i), docCounts.get(i));
+            expectedMinsBigKeys.put(Long.MAX_VALUE * keys.get(i), Long.MAX_VALUE * mins.get(i));
+            expectedMaxesBigKeys.put(Long.MAX_VALUE * keys.get(i), Long.MAX_VALUE * maxes.get(i));
+        }
+
+        testSearchCase(DEFAULT_QUERY, dataset.stream().map(n -> Double.valueOf(n.doubleValue() * Long.MAX_VALUE)).collect(toList()), false,
+            aggregation -> aggregation.field(NUMERIC_FIELD).setNumBuckets(2).setShardSize(16).setInitialBuffer(12),
+            histogram -> {
+                final List<InternalVariableWidthHistogram.Bucket> buckets = histogram.getBuckets();
+                assertEquals(expectedDocCountOnlySearch.size(), buckets.size());
+                buckets.forEach(bucket -> {
+                    long expectedDocCount = expectedDocCountBigKeys.getOrDefault(bucket.getKey(), 0).longValue();
+                    double expectedCentroid = expectedMinsBigKeys.getOrDefault(bucket.getKey(), 0d).doubleValue();
+                    double expectedMax = expectedMaxesBigKeys.getOrDefault(bucket.getKey(), 0d).doubleValue();
                     assertEquals(expectedDocCount, bucket.getDocCount());
                     assertEquals(expectedCentroid, bucket.min(), doubleError);
                     assertEquals(expectedMax, bucket.max(), doubleError);
@@ -308,7 +332,6 @@ public class VariableWidthHistogramAggregatorTests extends AggregatorTestCase {
                 .setShardSize(4)
                 .subAggregation(AggregationBuilders.stats("stats").field(NUMERIC_FIELD)),
             histogram -> {
-                final List<InternalVariableWidthHistogram.Bucket> buckets = histogram.getBuckets();
                 double deltaError = 1d/10000d;
 
                 // Expected clusters: [ (1, 2), (5), (8,9) ]
@@ -343,7 +366,6 @@ public class VariableWidthHistogramAggregatorTests extends AggregatorTestCase {
                 .setShardSize(4)
                 .subAggregation(new StatsAggregationBuilder("stats").field(NUMERIC_FIELD)),
             histogram -> {
-                final List<InternalVariableWidthHistogram.Bucket> buckets = histogram.getBuckets();
                 double deltaError = 1d / 10000d;
 
                 // Expected clusters: [ (0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11) ]
@@ -381,16 +403,15 @@ public class VariableWidthHistogramAggregatorTests extends AggregatorTestCase {
                                         .shardSize(2)
                                         .size(1)),
             histogram -> {
-                final List<InternalVariableWidthHistogram.Bucket> buckets = histogram.getBuckets();
                 double deltaError = 1d / 10000d;
 
                 // This is a test to make sure that the sub aggregations get reduced
                 // This terms sub aggregation has shardSize (2) != size (1), so we will get 1 bucket only if
                 // InternalVariableWidthHistogram reduces the sub aggregations.
 
-                InternalTerms terms = histogram.getBuckets().get(0).getAggregations().get("terms");
+                LongTerms terms = histogram.getBuckets().get(0).getAggregations().get("terms");
                 assertEquals(1L, terms.getBuckets().size(), deltaError);
-                assertEquals(1L, ((InternalTerms.Bucket) terms.getBuckets().get(0)).getKey());
+                assertEquals(1L, terms.getBuckets().get(0).getKey());
             });
     }