Browse Source

Optimize geogrid aggregations for singleton points (#87290)

Use a different code path when we detect that the points on a segment
are all single valued which allow us to optimise the tight loop used
while iterating the doc values.

Checks on rally shows a modest, still nice performance improvement:

```
            
|                                                        Metric |        Task |       Baseline |      Contender |        Diff |   Unit |   Diff % |
|--------------------------------------------------------------:|------------:|---------------:|---------------:|------------:|-------:|---------:|
|                                                Min Throughput | geotilegrid |    0.573237    |    0.600875    |     0.02764 |  ops/s |   +4.82% |
|                                               Mean Throughput | geotilegrid |    0.607767    |    0.630151    |     0.02238 |  ops/s |   +3.68% |
|                                             Median Throughput | geotilegrid |    0.612317    |    0.633243    |     0.02093 |  ops/s |   +3.42% |
|                                                Max Throughput | geotilegrid |    0.623922    |    0.645545    |     0.02162 |  ops/s |   +3.47% |
|                                       50th percentile latency | geotilegrid | 1539.4         | 1492.69        |   -46.7073  |     ms |   -3.03% |
|                                       90th percentile latency | geotilegrid | 1570.21        | 1514.08        |   -56.1252  |     ms |   -3.57% |
|                                      100th percentile latency | geotilegrid | 1573.78        | 1515.79        |   -57.9968  |     ms |   -3.69% |
|                                  50th percentile service time | geotilegrid | 1539.4         | 1492.69        |   -46.7073  |     ms |   -3.03% |
|                                  90th percentile service time | geotilegrid | 1570.21        | 1514.08        |   -56.1252  |     ms |   -3.57% |
|                                 100th percentile service time | geotilegrid | 1573.78        | 1515.79        |   -57.9968  |     ms |   -3.69% |
|                                                    error rate | geotilegrid |    0           |    0           |     0       |      % |    0.00% |
|                                                Min Throughput | geohashgrid |    2.5986      |    2.99851     |     0.39992 |  ops/s |  +15.39% |
|                                               Mean Throughput | geohashgrid |    2.64153     |    3.02417     |     0.38264 |  ops/s |  +14.49% |
|                                             Median Throughput | geohashgrid |    2.65188     |    3.02704     |     0.37516 |  ops/s |  +14.15% |
|                                                Max Throughput | geohashgrid |    2.66263     |    3.03953     |     0.3769  |  ops/s |  +14.16% |
|                                       50th percentile latency | geohashgrid |  371.621       |  328.431       |   -43.19    |     ms |  -11.62% |
|                                       90th percentile latency | geohashgrid |  373.7         |  331.22        |   -42.4795  |     ms |  -11.37% |
|                                      100th percentile latency | geohashgrid |  374.082       |  332.37        |   -41.712   |     ms |  -11.15% |
|                                  50th percentile service time | geohashgrid |  371.621       |  328.431       |   -43.19    |     ms |  -11.62% |
|                                  90th percentile service time | geohashgrid |  373.7         |  331.22        |   -42.4795  |     ms |  -11.37% |
|                                 100th percentile service time | geohashgrid |  374.082       |  332.37        |   -41.712   |     ms |  -11.15% |
|                                                    error rate | geohashgrid |    0           |    0           |     0       |      % |    0.00% |
|                                                Min Throughput |  geohexgrid |    0.125189    |    0.132832    |     0.00764 |  ops/s |   +6.11% |
|                                               Mean Throughput |  geohexgrid |    0.127163    |    0.133508    |     0.00634 |  ops/s |   +4.99% |
|                                             Median Throughput |  geohexgrid |    0.126976    |    0.133546    |     0.00657 |  ops/s |   +5.17% |
|                                                Max Throughput |  geohexgrid |    0.131959    |    0.134025    |     0.00207 |  ops/s |   +1.57% |
|                                       50th percentile latency |  geohexgrid | 7802.83        | 7426.77        |  -376.056   |     ms |   -4.82% |
|                                       90th percentile latency |  geohexgrid | 8501.63        | 7552.32        |  -949.308   |     ms |  -11.17% |
|                                      100th percentile latency |  geohexgrid | 9058.2         | 7726.13        | -1332.07    |     ms |  -14.71% |
|                                  50th percentile service time |  geohexgrid | 7802.83        | 7426.77        |  -376.056   |     ms |   -4.82% |
|                                  90th percentile service time |  geohexgrid | 8501.63        | 7552.32        |  -949.308   |     ms |  -11.17% |
|                                 100th percentile service time |  geohexgrid | 9058.2         | 7726.13        | -1332.07    |     ms |  -14.71% |
|                                                    error rate |  geohexgrid |    0           |    0           |     0       |      % |    0.00% |
```
Ignacio Vera 3 năm trước cách đây
mục cha
commit
30406baabf

+ 5 - 0
docs/changelog/87290.yaml

@@ -0,0 +1,5 @@
+pr: 87290
+summary: Optimize geogrid aggregations for singleton points
+area: Geo
+type: enhancement
+issues: []

+ 109 - 7
server/src/main/java/org/elasticsearch/search/aggregations/bucket/geogrid/CellIdSource.java

@@ -8,16 +8,24 @@
 
 package org.elasticsearch.search.aggregations.bucket.geogrid;
 
+import org.apache.lucene.index.DocValues;
 import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.NumericDocValues;
 import org.apache.lucene.index.SortedNumericDocValues;
 import org.elasticsearch.common.geo.GeoBoundingBox;
+import org.elasticsearch.index.fielddata.AbstractNumericDocValues;
+import org.elasticsearch.index.fielddata.AbstractSortingNumericDocValues;
+import org.elasticsearch.index.fielddata.GeoPointValues;
 import org.elasticsearch.index.fielddata.MultiGeoPointValues;
 import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
 import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
 import org.elasticsearch.search.aggregations.support.ValuesSource;
 
+import java.io.IOException;
+
 /**
- * Base class to help convert {@link MultiGeoPointValues} to {@link CellValues}
+ * Base class to help convert {@link MultiGeoPointValues} to {@link CellMultiValues}
+ * and {@link GeoPointValues} to {@link CellSingleValue}
  */
 public abstract class CellIdSource extends ValuesSource.Numeric {
 
@@ -45,22 +53,35 @@ public abstract class CellIdSource extends ValuesSource.Numeric {
     @Override
     public final SortedNumericDocValues longValues(LeafReaderContext ctx) {
         final MultiGeoPointValues multiGeoPointValues = valuesSource.geoPointValues(ctx);
+        final GeoPointValues values = org.elasticsearch.index.fielddata.FieldData.unwrapSingleton(multiGeoPointValues);
         if (geoBoundingBox.isUnbounded()) {
-            return unboundedCellValues(multiGeoPointValues);
+            return values == null ? unboundedCellMultiValues(multiGeoPointValues) : DocValues.singleton(unboundedCellSingleValue(values));
         } else {
-            return boundedCellValues(multiGeoPointValues, geoBoundingBox);
+            return values == null
+                ? boundedCellMultiValues(multiGeoPointValues, geoBoundingBox)
+                : DocValues.singleton(boundedCellSingleValue(values, geoBoundingBox));
         }
     }
 
     /**
-     * Generate an unbounded iterator of grid-cells
+     * Generate an unbounded iterator of grid-cells for singleton case.
+     */
+    protected abstract NumericDocValues unboundedCellSingleValue(GeoPointValues values);
+
+    /**
+     * Generate a bounded iterator of grid-cells for singleton case.
      */
-    protected abstract CellValues unboundedCellValues(MultiGeoPointValues values);
+    protected abstract NumericDocValues boundedCellSingleValue(GeoPointValues values, GeoBoundingBox boundingBox);
 
     /**
-     * Generate a bounded iterator of grid-cells
+     * Generate an unbounded iterator of grid-cells for multi-value case.
      */
-    protected abstract CellValues boundedCellValues(MultiGeoPointValues values, GeoBoundingBox boundingBox);
+    protected abstract SortedNumericDocValues unboundedCellMultiValues(MultiGeoPointValues values);
+
+    /**
+     * Generate a bounded iterator of grid-cells for multi-value case.
+     */
+    protected abstract SortedNumericDocValues boundedCellMultiValues(MultiGeoPointValues values, GeoBoundingBox boundingBox);
 
     @Override
     public final SortedNumericDoubleValues doubleValues(LeafReaderContext ctx) {
@@ -89,4 +110,85 @@ public abstract class CellIdSource extends ValuesSource.Numeric {
         return false;
     }
 
+    /**
+     * Class representing the long-encoded grid-cells belonging to
+     * the multi-value geo-doc-values. Class must encode the values and then
+     * sort them in order to account for the cells correctly.
+     */
+    protected abstract static class CellMultiValues extends AbstractSortingNumericDocValues {
+        private final MultiGeoPointValues geoValues;
+        protected final int precision;
+
+        protected CellMultiValues(MultiGeoPointValues geoValues, int precision) {
+            this.geoValues = geoValues;
+            this.precision = precision;
+        }
+
+        @Override
+        public boolean advanceExact(int docId) throws IOException {
+            if (geoValues.advanceExact(docId)) {
+                int docValueCount = geoValues.docValueCount();
+                resize(docValueCount);
+                int j = 0;
+                for (int i = 0; i < docValueCount; i++) {
+                    j = advanceValue(geoValues.nextValue(), j);
+                }
+                resize(j);
+                sort();
+                return true;
+            } else {
+                return false;
+            }
+        }
+
+        /**
+         * Sets the appropriate long-encoded value for <code>target</code>
+         * in <code>values</code>.
+         *
+         * @param target    the geo-value to encode
+         * @param valuesIdx the index into <code>values</code> to set
+         * @return          valuesIdx + 1 if value was set, valuesIdx otherwise.
+         */
+        protected abstract int advanceValue(org.elasticsearch.common.geo.GeoPoint target, int valuesIdx);
+    }
+
+    /**
+     * Class representing the long-encoded grid-cells belonging to
+     * the singleton geo-doc-values.
+     */
+    protected abstract static class CellSingleValue extends AbstractNumericDocValues {
+        private final GeoPointValues geoValues;
+        protected final int precision;
+        protected long value;
+
+        protected CellSingleValue(GeoPointValues geoValues, int precision) {
+            this.geoValues = geoValues;
+            this.precision = precision;
+
+        }
+
+        @Override
+        public boolean advanceExact(int docId) throws IOException {
+            return geoValues.advanceExact(docId) && advance(geoValues.geoPointValue());
+        }
+
+        @Override
+        public long longValue() throws IOException {
+            return value;
+        }
+
+        /**
+         * Sets the appropriate long-encoded value for <code>target</code>
+         * in <code>value</code>.
+         *
+         * @param target    the geo-value to encode
+         * @return          true if the value needs to be added, otherwise false.
+         */
+        protected abstract boolean advance(org.elasticsearch.common.geo.GeoPoint target);
+
+        @Override
+        public int docID() {
+            return -1;
+        }
+    }
 }

+ 0 - 55
server/src/main/java/org/elasticsearch/search/aggregations/bucket/geogrid/CellValues.java

@@ -1,55 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0 and the Server Side Public License, v 1; you may not use this file except
- * in compliance with, at your election, the Elastic License 2.0 or the Server
- * Side Public License, v 1.
- */
-package org.elasticsearch.search.aggregations.bucket.geogrid;
-
-import org.elasticsearch.index.fielddata.AbstractSortingNumericDocValues;
-import org.elasticsearch.index.fielddata.MultiGeoPointValues;
-
-import java.io.IOException;
-
-/**
- * Class representing the long-encoded grid-cells belonging to
- * the geo-doc-values. Class must encode the values and then
- * sort them in order to account for the cells correctly.
- */
-public abstract class CellValues extends AbstractSortingNumericDocValues {
-    private MultiGeoPointValues geoValues;
-    protected int precision;
-
-    protected CellValues(MultiGeoPointValues geoValues, int precision) {
-        this.geoValues = geoValues;
-        this.precision = precision;
-    }
-
-    @Override
-    public boolean advanceExact(int docId) throws IOException {
-        if (geoValues.advanceExact(docId)) {
-            int docValueCount = geoValues.docValueCount();
-            resize(docValueCount);
-            int j = 0;
-            for (int i = 0; i < docValueCount; i++) {
-                j = advanceValue(geoValues.nextValue(), j);
-            }
-            resize(j);
-            sort();
-            return true;
-        } else {
-            return false;
-        }
-    }
-
-    /**
-     * Sets the appropriate long-encoded value for <code>target</code>
-     * in <code>values</code>.
-     *
-     * @param target    the geo-value to encode
-     * @param valuesIdx the index into <code>values</code> to set
-     * @return          valuesIdx + 1 if value was set, valuesIdx otherwise.
-     */
-    protected abstract int advanceValue(org.elasticsearch.common.geo.GeoPoint target, int valuesIdx);
-}

+ 27 - 2
server/src/main/java/org/elasticsearch/search/aggregations/bucket/geogrid/GeoGridAggregator.java

@@ -7,7 +7,9 @@
  */
 package org.elasticsearch.search.aggregations.bucket.geogrid;
 
+import org.apache.lucene.index.DocValues;
 import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.NumericDocValues;
 import org.apache.lucene.index.SortedNumericDocValues;
 import org.apache.lucene.search.ScoreMode;
 import org.elasticsearch.core.Releasables;
@@ -65,8 +67,31 @@ public abstract class GeoGridAggregator<T extends InternalGeoGrid<?>> extends Bu
     }
 
     @Override
-    public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
-        SortedNumericDocValues values = valuesSource.longValues(ctx);
+    public LeafBucketCollector getLeafCollector(final LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
+        final SortedNumericDocValues values = valuesSource.longValues(ctx);
+        final NumericDocValues singleton = DocValues.unwrapSingleton(values);
+        return singleton != null ? getLeafCollector(singleton, sub) : getLeafCollector(values, sub);
+    }
+
+    private LeafBucketCollector getLeafCollector(final NumericDocValues values, final LeafBucketCollector sub) {
+        return new LeafBucketCollectorBase(sub, null) {
+            @Override
+            public void collect(int doc, long owningBucketOrd) throws IOException {
+                if (values.advanceExact(doc)) {
+                    final long val = values.longValue();
+                    long bucketOrdinal = bucketOrds.add(owningBucketOrd, val);
+                    if (bucketOrdinal < 0) { // already seen
+                        bucketOrdinal = -1 - bucketOrdinal;
+                        collectExistingBucket(sub, doc, bucketOrdinal);
+                    } else {
+                        collectBucket(sub, doc, bucketOrdinal);
+                    }
+                }
+            }
+        };
+    }
+
+    private LeafBucketCollector getLeafCollector(final SortedNumericDocValues values, final LeafBucketCollector sub) {
         return new LeafBucketCollectorBase(sub, null) {
             @Override
             public void collect(int doc, long owningBucketOrd) throws IOException {

+ 35 - 5
server/src/main/java/org/elasticsearch/search/aggregations/bucket/geogrid/GeoHashCellIdSource.java

@@ -7,13 +7,16 @@
  */
 package org.elasticsearch.search.aggregations.bucket.geogrid;
 
+import org.apache.lucene.index.NumericDocValues;
+import org.apache.lucene.index.SortedNumericDocValues;
 import org.elasticsearch.common.geo.GeoBoundingBox;
 import org.elasticsearch.geometry.utils.Geohash;
+import org.elasticsearch.index.fielddata.GeoPointValues;
 import org.elasticsearch.index.fielddata.MultiGeoPointValues;
 import org.elasticsearch.search.aggregations.support.ValuesSource;
 
 /**
- * Class to help convert {@link MultiGeoPointValues} to Geohash {@link CellValues}
+ * {@link CellIdSource} implementation for Geohash aggregation
  */
 public class GeoHashCellIdSource extends CellIdSource {
 
@@ -22,8 +25,35 @@ public class GeoHashCellIdSource extends CellIdSource {
     }
 
     @Override
-    protected CellValues unboundedCellValues(MultiGeoPointValues values) {
-        return new CellValues(values, precision()) {
+    protected NumericDocValues unboundedCellSingleValue(GeoPointValues values) {
+        return new CellSingleValue(values, precision()) {
+            @Override
+            protected boolean advance(org.elasticsearch.common.geo.GeoPoint target) {
+                value = Geohash.longEncode(target.getLon(), target.getLat(), precision);
+                return true;
+            }
+        };
+    }
+
+    @Override
+    protected NumericDocValues boundedCellSingleValue(GeoPointValues values, GeoBoundingBox boundingBox) {
+        final GeoHashBoundedPredicate predicate = new GeoHashBoundedPredicate(precision(), boundingBox);
+        return new CellSingleValue(values, precision()) {
+            @Override
+            protected boolean advance(org.elasticsearch.common.geo.GeoPoint target) {
+                final String hash = Geohash.stringEncode(target.getLon(), target.getLat(), precision);
+                if (validPoint(target.getLon(), target.getLat()) || predicate.validHash(hash)) {
+                    value = Geohash.longEncode(hash);
+                    return true;
+                }
+                return false;
+            }
+        };
+    }
+
+    @Override
+    protected SortedNumericDocValues unboundedCellMultiValues(MultiGeoPointValues values) {
+        return new CellMultiValues(values, precision()) {
             @Override
             protected int advanceValue(org.elasticsearch.common.geo.GeoPoint target, int valuesIdx) {
                 values[valuesIdx] = Geohash.longEncode(target.getLon(), target.getLat(), precision);
@@ -33,9 +63,9 @@ public class GeoHashCellIdSource extends CellIdSource {
     }
 
     @Override
-    protected CellValues boundedCellValues(MultiGeoPointValues values, GeoBoundingBox boundingBox) {
+    protected SortedNumericDocValues boundedCellMultiValues(MultiGeoPointValues values, GeoBoundingBox boundingBox) {
         final GeoHashBoundedPredicate predicate = new GeoHashBoundedPredicate(precision(), boundingBox);
-        return new CellValues(values, precision()) {
+        return new CellMultiValues(values, precision()) {
             @Override
             protected int advanceValue(org.elasticsearch.common.geo.GeoPoint target, int valuesIdx) {
                 final String hash = Geohash.stringEncode(target.getLon(), target.getLat(), precision);

+ 37 - 5
server/src/main/java/org/elasticsearch/search/aggregations/bucket/geogrid/GeoTileCellIdSource.java

@@ -7,11 +7,14 @@
  */
 package org.elasticsearch.search.aggregations.bucket.geogrid;
 
+import org.apache.lucene.index.NumericDocValues;
+import org.apache.lucene.index.SortedNumericDocValues;
 import org.elasticsearch.common.geo.GeoBoundingBox;
+import org.elasticsearch.index.fielddata.GeoPointValues;
 import org.elasticsearch.index.fielddata.MultiGeoPointValues;
 
 /**
- * Class to help convert {@link MultiGeoPointValues} to GeoTile {@link CellValues}
+ * {@link CellIdSource} implementation for GeoTile aggregation
  */
 public class GeoTileCellIdSource extends CellIdSource {
 
@@ -20,8 +23,37 @@ public class GeoTileCellIdSource extends CellIdSource {
     }
 
     @Override
-    protected CellValues unboundedCellValues(MultiGeoPointValues values) {
-        return new CellValues(values, precision()) {
+    protected NumericDocValues unboundedCellSingleValue(GeoPointValues values) {
+        return new CellSingleValue(values, precision()) {
+            @Override
+            protected boolean advance(org.elasticsearch.common.geo.GeoPoint target) {
+                value = GeoTileUtils.longEncode(target.getLon(), target.getLat(), precision);
+                return true;
+            }
+        };
+    }
+
+    @Override
+    protected NumericDocValues boundedCellSingleValue(GeoPointValues values, GeoBoundingBox boundingBox) {
+        final GeoTileBoundedPredicate predicate = new GeoTileBoundedPredicate(precision(), boundingBox);
+        final long tiles = 1L << precision();
+        return new CellSingleValue(values, precision()) {
+            @Override
+            protected boolean advance(org.elasticsearch.common.geo.GeoPoint target) {
+                final int x = GeoTileUtils.getXTile(target.getLon(), tiles);
+                final int y = GeoTileUtils.getYTile(target.getLat(), tiles);
+                if (predicate.validTile(x, y, precision)) {
+                    value = GeoTileUtils.longEncodeTiles(precision, x, y);
+                    return true;
+                }
+                return false;
+            }
+        };
+    }
+
+    @Override
+    protected SortedNumericDocValues unboundedCellMultiValues(MultiGeoPointValues values) {
+        return new CellMultiValues(values, precision()) {
             @Override
             protected int advanceValue(org.elasticsearch.common.geo.GeoPoint target, int valuesIdx) {
                 values[valuesIdx] = GeoTileUtils.longEncode(target.getLon(), target.getLat(), precision);
@@ -31,10 +63,10 @@ public class GeoTileCellIdSource extends CellIdSource {
     }
 
     @Override
-    protected CellValues boundedCellValues(MultiGeoPointValues values, GeoBoundingBox boundingBox) {
+    protected SortedNumericDocValues boundedCellMultiValues(MultiGeoPointValues values, GeoBoundingBox boundingBox) {
         final GeoTileBoundedPredicate predicate = new GeoTileBoundedPredicate(precision(), boundingBox);
         final long tiles = 1L << precision();
-        return new CellValues(values, precision()) {
+        return new CellMultiValues(values, precision()) {
             @Override
             protected int advanceValue(org.elasticsearch.common.geo.GeoPoint target, int valuesIdx) {
                 final int x = GeoTileUtils.getXTile(target.getLon(), tiles);

+ 94 - 83
test/framework/src/main/java/org/elasticsearch/search/aggregations/bucket/geogrid/GeoGridAggregatorTestCase.java

@@ -19,6 +19,7 @@ import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.geo.GeoBoundingBox;
 import org.elasticsearch.core.CheckedConsumer;
@@ -47,8 +48,8 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.TreeMap;
+import java.util.function.BooleanSupplier;
 import java.util.function.Consumer;
-import java.util.function.Function;
 
 import static org.hamcrest.Matchers.equalTo;
 
@@ -133,16 +134,34 @@ public abstract class GeoGridAggregatorTestCase<T extends InternalGeoGridBucket>
 
     }
 
-    public void testWithSeveralDocs() throws IOException {
+    public void testSingletonDocs() throws IOException {
+        testWithSeveralDocs(() -> true, null);
+    }
+
+    public void testBoundedSingletonDocs() throws IOException {
+        testWithSeveralDocs(() -> true, randomBBox());
+    }
+
+    public void testMultiValuedDocs() throws IOException {
+        testWithSeveralDocs(LuceneTestCase::rarely, null);
+    }
+
+    public void testBoundedMultiValuedDocs() throws IOException {
+        testWithSeveralDocs(LuceneTestCase::rarely, randomBBox());
+    }
+
+    private void testWithSeveralDocs(BooleanSupplier supplier, GeoBoundingBox bbox) throws IOException {
         int precision = randomPrecision();
         int numPoints = randomIntBetween(8, 128);
         Map<String, Integer> expectedCountPerGeoHash = new HashMap<>();
-        testCase(new MatchAllDocsQuery(), FIELD_NAME, precision, null, geoHashGrid -> {
+        testCase(new MatchAllDocsQuery(), FIELD_NAME, precision, bbox, geoHashGrid -> {
             assertEquals(expectedCountPerGeoHash.size(), geoHashGrid.getBuckets().size());
             for (GeoGrid.Bucket bucket : geoHashGrid.getBuckets()) {
                 assertEquals((long) expectedCountPerGeoHash.get(bucket.getKeyAsString()), bucket.getDocCount());
             }
-            assertTrue(AggregationInspectionHelper.hasValue(geoHashGrid));
+            if (bbox == null) {
+                assertTrue(AggregationInspectionHelper.hasValue(geoHashGrid));
+            }
         }, iw -> {
             List<LatLonDocValuesField> points = new ArrayList<>();
             Set<String> distinctHashesPerDoc = new HashSet<>();
@@ -150,11 +169,14 @@ public abstract class GeoGridAggregatorTestCase<T extends InternalGeoGridBucket>
                 double[] latLng = randomLatLng();
                 points.add(new LatLonDocValuesField(FIELD_NAME, latLng[0], latLng[1]));
                 String hash = hashAsString(latLng[1], latLng[0], precision);
-                if (distinctHashesPerDoc.contains(hash) == false) {
-                    expectedCountPerGeoHash.put(hash, expectedCountPerGeoHash.getOrDefault(hash, 0) + 1);
+                Rectangle bin = getTile(latLng[1], latLng[0], precision);
+                if (intersectsBounds(bin, bbox) || validPoint(latLng[1], latLng[0], bbox)) {
+                    if (distinctHashesPerDoc.contains(hash) == false) {
+                        expectedCountPerGeoHash.put(hash, expectedCountPerGeoHash.getOrDefault(hash, 0) + 1);
+                    }
+                    distinctHashesPerDoc.add(hash);
                 }
-                distinctHashesPerDoc.add(hash);
-                if (usually()) {
+                if (supplier.getAsBoolean()) {
                     iw.addDocument(points);
                     points.clear();
                     distinctHashesPerDoc.clear();
@@ -166,35 +188,64 @@ public abstract class GeoGridAggregatorTestCase<T extends InternalGeoGridBucket>
         });
     }
 
-    public void testAsSubAgg() throws IOException {
+    public void testSingletonDocsAsSubAgg() throws IOException {
+        testWithSeveralDocsAsSubAgg(() -> true, null);
+    }
+
+    public void testBoundedSingletonDocsAsSubAgg() throws IOException {
+        testWithSeveralDocsAsSubAgg(() -> true, randomBBox());
+    }
+
+    public void testMultiValuedDocsAsSubAgg() throws IOException {
+        testWithSeveralDocsAsSubAgg(LuceneTestCase::rarely, null);
+    }
+
+    public void testBoundedMultiValuedDocsAsSubAgg() throws IOException {
+        testWithSeveralDocsAsSubAgg(LuceneTestCase::rarely, randomBBox());
+    }
+
+    private void testWithSeveralDocsAsSubAgg(BooleanSupplier supplier, GeoBoundingBox bbox) throws IOException {
         int precision = randomPrecision();
+        int numPoints = randomIntBetween(8, 128);
         Map<String, Map<String, Long>> expectedCountPerTPerGeoHash = new TreeMap<>();
-        List<List<IndexableField>> docs = new ArrayList<>();
-        for (int i = 0; i < 30; i++) {
+        TermsAggregationBuilder aggregationBuilder = new TermsAggregationBuilder("t").field("t").size(numPoints);
+        GeoGridAggregationBuilder gridBuilder = createBuilder("gg").field(FIELD_NAME).precision(precision);
+        if (bbox != null) {
+            gridBuilder.setGeoBoundingBox(bbox);
+        }
+        aggregationBuilder.subAggregation(gridBuilder);
+        testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> {
+            List<IndexableField> fields = new ArrayList<>();
+            Set<String> distinctHashesPerDoc = new HashSet<>();
             String t = randomAlphaOfLength(1);
-            double[] latLng = randomLatLng();
-
-            List<IndexableField> doc = new ArrayList<>();
-            docs.add(doc);
-            doc.add(new LatLonDocValuesField(FIELD_NAME, latLng[0], latLng[1]));
-            doc.add(new SortedSetDocValuesField("t", new BytesRef(t)));
-            doc.add(new Field("t", new BytesRef(t), KeywordFieldMapper.Defaults.FIELD_TYPE));
-
-            String hash = hashAsString(latLng[1], latLng[0], precision);
-            Map<String, Long> expectedCountPerGeoHash = expectedCountPerTPerGeoHash.get(t);
-            if (expectedCountPerGeoHash == null) {
-                expectedCountPerGeoHash = new TreeMap<>();
-                expectedCountPerTPerGeoHash.put(t, expectedCountPerGeoHash);
+            for (int pointId = 0; pointId < numPoints; pointId++) {
+                Map<String, Long> expectedCountPerGeoHash = expectedCountPerTPerGeoHash.computeIfAbsent(t, k -> new TreeMap<>());
+                double[] latLng = randomLatLng();
+                fields.add(new LatLonDocValuesField(FIELD_NAME, latLng[0], latLng[1]));
+                String hash = hashAsString(latLng[1], latLng[0], precision);
+                if (distinctHashesPerDoc.contains(hash) == false) {
+                    if (intersectsBounds(getTile(latLng[1], latLng[0], precision), bbox) || validPoint(latLng[1], latLng[0], bbox)) {
+                        expectedCountPerGeoHash.put(hash, expectedCountPerGeoHash.getOrDefault(hash, 0L) + 1);
+                        distinctHashesPerDoc.add(hash);
+                    }
+                }
+                if (supplier.getAsBoolean()) {
+                    fields.add(new SortedSetDocValuesField("t", new BytesRef(t)));
+                    fields.add(new Field("t", new BytesRef(t), KeywordFieldMapper.Defaults.FIELD_TYPE));
+                    iw.addDocument(fields);
+                    fields.clear();
+                    distinctHashesPerDoc.clear();
+                    t = randomAlphaOfLength(1);
+                }
             }
-            expectedCountPerGeoHash.put(hash, expectedCountPerGeoHash.getOrDefault(hash, 0L) + 1);
-        }
-        CheckedConsumer<RandomIndexWriter, IOException> buildIndex = iw -> iw.addDocuments(docs);
-        TermsAggregationBuilder aggregationBuilder = new TermsAggregationBuilder("t").field("t")
-            .size(expectedCountPerTPerGeoHash.size())
-            .subAggregation(createBuilder("gg").field(FIELD_NAME).precision(precision));
-        Consumer<StringTerms> verify = (terms) -> {
+            if (fields.size() != 0) {
+                fields.add(new SortedSetDocValuesField("t", new BytesRef(t)));
+                fields.add(new Field("t", new BytesRef(t), KeywordFieldMapper.Defaults.FIELD_TYPE));
+                iw.addDocument(fields);
+            }
+        }, terms -> {
             Map<String, Map<String, Long>> actual = new TreeMap<>();
-            for (StringTerms.Bucket tb : terms.getBuckets()) {
+            for (StringTerms.Bucket tb : ((StringTerms) terms).getBuckets()) {
                 InternalGeoGrid<?> gg = tb.getAggregations().get("gg");
                 Map<String, Long> sub = new TreeMap<>();
                 for (InternalGeoGridBucket ggb : gg.getBuckets()) {
@@ -203,69 +254,26 @@ public abstract class GeoGridAggregatorTestCase<T extends InternalGeoGridBucket>
                 actual.put(tb.getKeyAsString(), sub);
             }
             assertThat(actual, equalTo(expectedCountPerTPerGeoHash));
-        };
-        testCase(aggregationBuilder, new MatchAllDocsQuery(), buildIndex, verify, keywordField("t"), geoPointField(FIELD_NAME));
+        }, keywordField("t"), geoPointField(FIELD_NAME));
     }
 
     private double[] randomLatLng() {
-        double lat = (180d * randomDouble()) - 90d;
-        double lng = (360d * randomDouble()) - 180d;
+        Point point = randomPoint();
 
         // Precision-adjust longitude/latitude to avoid wrong bucket placement
         // Internally, lat/lng get converted to 32 bit integers, loosing some precision.
         // This does not affect geohashing because geohash uses the same algorithm,
         // but it does affect other bucketing algos, thus we need to do the same steps here.
-        lng = GeoEncodingUtils.decodeLongitude(GeoEncodingUtils.encodeLongitude(lng));
-        lat = GeoEncodingUtils.decodeLatitude(GeoEncodingUtils.encodeLatitude(lat));
+        double lon = GeoEncodingUtils.decodeLongitude(GeoEncodingUtils.encodeLongitude(point.getLon()));
+        double lat = GeoEncodingUtils.decodeLatitude(GeoEncodingUtils.encodeLatitude(point.getLat()));
 
-        return new double[] { lat, lng };
-    }
-
-    public void testBounds() throws IOException {
-        final int numDocs = randomIntBetween(64, 256);
-        final GeoGridAggregationBuilder builder = createBuilder("_name");
-
-        expectThrows(IllegalArgumentException.class, () -> builder.precision(-1));
-        expectThrows(IllegalArgumentException.class, () -> builder.precision(30));
-
-        GeoBoundingBox bbox = randomBBox();
-
-        Function<Double, Double> encodeDecodeLat = (lat) -> GeoEncodingUtils.decodeLatitude(GeoEncodingUtils.encodeLatitude(lat));
-        Function<Double, Double> encodeDecodeLon = (lon) -> GeoEncodingUtils.decodeLongitude(GeoEncodingUtils.encodeLongitude(lon));
-        final int precision = randomPrecision();
-        int in = 0;
-        List<LatLonDocValuesField> docs = new ArrayList<>();
-        for (int i = 0; i < numDocs; i++) {
-            Point p = randomPoint();
-            double x = encodeDecodeLon.apply(p.getLon());
-            double y = encodeDecodeLat.apply(p.getLat());
-            Rectangle pointTile = getTile(x, y, precision);
-            if (intersectsBounds(pointTile, bbox) || validPoint(x, y, bbox)) {
-                in++;
-            }
-            docs.add(new LatLonDocValuesField(FIELD_NAME, p.getLat(), p.getLon()));
-        }
-
-        final long numDocsInBucket = in;
-        testCase(new MatchAllDocsQuery(), FIELD_NAME, precision, bbox, geoGrid -> {
-            if (numDocsInBucket > 0) {
-                assertTrue(AggregationInspectionHelper.hasValue(geoGrid));
-                long docCount = 0;
-                for (int i = 0; i < geoGrid.getBuckets().size(); i++) {
-                    docCount += geoGrid.getBuckets().get(i).getDocCount();
-                }
-                assertThat(docCount, equalTo(numDocsInBucket));
-            } else {
-                assertFalse(AggregationInspectionHelper.hasValue(geoGrid));
-            }
-        }, iw -> {
-            for (LatLonDocValuesField docField : docs) {
-                iw.addDocument(Collections.singletonList(docField));
-            }
-        });
+        return new double[] { lat, lon };
     }
 
     private boolean validPoint(double x, double y, GeoBoundingBox bbox) {
+        if (bbox == null) {
+            return true;
+        }
         if (bbox.top() > y && bbox.bottom() < y) {
             boolean crossesDateline = bbox.left() > bbox.right();
             if (crossesDateline) {
@@ -278,6 +286,9 @@ public abstract class GeoGridAggregatorTestCase<T extends InternalGeoGridBucket>
     }
 
     private boolean intersectsBounds(Rectangle pointTile, GeoBoundingBox bbox) {
+        if (bbox == null) {
+            return true;
+        }
         if (pointTile.getMinX() > pointTile.getMaxX()) {
             Rectangle right = new Rectangle(pointTile.getMinX(), 180, pointTile.getMaxY(), pointTile.getMinY());
             Rectangle left = new Rectangle(-180, pointTile.getMaxX(), pointTile.getMaxY(), pointTile.getMinY());

+ 38 - 6
x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/search/aggregations/bucket/geogrid/GeoHexCellIdSource.java

@@ -6,15 +6,17 @@
  */
 package org.elasticsearch.xpack.spatial.search.aggregations.bucket.geogrid;
 
+import org.apache.lucene.index.NumericDocValues;
+import org.apache.lucene.index.SortedNumericDocValues;
 import org.elasticsearch.common.geo.GeoBoundingBox;
 import org.elasticsearch.h3.CellBoundary;
 import org.elasticsearch.h3.H3;
+import org.elasticsearch.index.fielddata.GeoPointValues;
 import org.elasticsearch.index.fielddata.MultiGeoPointValues;
 import org.elasticsearch.search.aggregations.bucket.geogrid.CellIdSource;
-import org.elasticsearch.search.aggregations.bucket.geogrid.CellValues;
 
 /**
-* Class to help convert {@link MultiGeoPointValues} to GeoHex {@link CellValues}
+* {@link CellIdSource} implementation for GeoHex aggregation
 */
 public class GeoHexCellIdSource extends CellIdSource {
 
@@ -23,8 +25,38 @@ public class GeoHexCellIdSource extends CellIdSource {
     }
 
     @Override
-    protected CellValues unboundedCellValues(MultiGeoPointValues values) {
-        return new CellValues(values, precision()) {
+    protected NumericDocValues unboundedCellSingleValue(GeoPointValues values) {
+        return new CellSingleValue(values, precision()) {
+            @Override
+            protected boolean advance(org.elasticsearch.common.geo.GeoPoint target) {
+                value = H3.geoToH3(target.getLat(), target.getLon(), precision);
+                return true;
+            }
+        };
+    }
+
+    @Override
+    protected NumericDocValues boundedCellSingleValue(GeoPointValues values, GeoBoundingBox boundingBox) {
+        final GeoHexPredicate predicate = new GeoHexPredicate(boundingBox, precision());
+        return new CellSingleValue(values, precision()) {
+            @Override
+            protected boolean advance(org.elasticsearch.common.geo.GeoPoint target) {
+                final double lat = target.getLat();
+                final double lon = target.getLon();
+                final long hex = H3.geoToH3(lat, lon, precision);
+                // validPoint is a fast check, validHex is slow
+                if (validPoint(lon, lat) || predicate.validHex(hex)) {
+                    value = hex;
+                    return true;
+                }
+                return false;
+            }
+        };
+    }
+
+    @Override
+    protected SortedNumericDocValues unboundedCellMultiValues(MultiGeoPointValues values) {
+        return new CellMultiValues(values, precision()) {
             @Override
             protected int advanceValue(org.elasticsearch.common.geo.GeoPoint target, int valuesIdx) {
                 values[valuesIdx] = H3.geoToH3(target.getLat(), target.getLon(), precision);
@@ -34,9 +66,9 @@ public class GeoHexCellIdSource extends CellIdSource {
     }
 
     @Override
-    protected CellValues boundedCellValues(MultiGeoPointValues values, GeoBoundingBox boundingBox) {
+    protected SortedNumericDocValues boundedCellMultiValues(MultiGeoPointValues values, GeoBoundingBox boundingBox) {
         final GeoHexPredicate predicate = new GeoHexPredicate(boundingBox, precision());
-        return new CellValues(values, precision()) {
+        return new CellMultiValues(values, precision()) {
             @Override
             protected int advanceValue(org.elasticsearch.common.geo.GeoPoint target, int valuesIdx) {
                 final double lat = target.getLat();