Browse Source

Improve accuracy for Geo Centroid Aggregation (#41033)

keeps the partial results as doubles and uses Kahan summation to help reduce floating point errors.
Ignacio Vera 6 years ago
parent
commit
cc48427e05

+ 5 - 5
docs/reference/aggregations/metrics/geocentroid-aggregation.asciidoc

@@ -58,8 +58,8 @@ The response for the above aggregation:
     "aggregations": {
         "centroid": {
             "location": {
-                "lat": 51.009829603135586,
-                "lon": 3.9662130642682314
+                "lat": 51.00982965203002,
+                "lon": 3.9662131341174245
             },
             "count": 6
         }
@@ -111,8 +111,8 @@ The response for the above aggregation:
                    "doc_count": 3,
                    "centroid": {
                       "location": {
-                         "lat": 52.371655642054975,
-                         "lon": 4.9095632415264845
+                         "lat": 52.371655656024814,
+                         "lon": 4.909563297405839
                       },
                       "count": 3
                    }
@@ -123,7 +123,7 @@ The response for the above aggregation:
                    "centroid": {
                       "location": {
                          "lat": 48.86055548675358,
-                         "lon": 2.331694420427084
+                         "lon": 2.3316944623366
                       },
                       "count": 2
                    }

+ 35 - 25
server/src/main/java/org/elasticsearch/search/aggregations/metrics/GeoCentroidAggregator.java

@@ -23,6 +23,7 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.elasticsearch.common.geo.GeoPoint;
 import org.elasticsearch.common.lease.Releasables;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.DoubleArray;
 import org.elasticsearch.common.util.LongArray;
 import org.elasticsearch.index.fielddata.MultiGeoPointValues;
 import org.elasticsearch.search.aggregations.Aggregator;
@@ -42,7 +43,7 @@ import java.util.Map;
  */
 final class GeoCentroidAggregator extends MetricsAggregator {
     private final ValuesSource.GeoPoint valuesSource;
-    private LongArray centroids;
+    private DoubleArray lonSum, lonCompensations, latSum, latCompensations;
     private LongArray counts;
 
     GeoCentroidAggregator(String name, SearchContext context, Aggregator parent,
@@ -52,7 +53,10 @@ final class GeoCentroidAggregator extends MetricsAggregator {
         this.valuesSource = valuesSource;
         if (valuesSource != null) {
             final BigArrays bigArrays = context.bigArrays();
-            centroids = bigArrays.newLongArray(1, true);
+            lonSum = bigArrays.newDoubleArray(1, true);
+            lonCompensations = bigArrays.newDoubleArray(1, true);
+            latSum = bigArrays.newDoubleArray(1, true);
+            latCompensations = bigArrays.newDoubleArray(1, true);
             counts = bigArrays.newLongArray(1, true);
         }
     }
@@ -67,33 +71,41 @@ final class GeoCentroidAggregator extends MetricsAggregator {
         return new LeafBucketCollectorBase(sub, values) {
             @Override
             public void collect(int doc, long bucket) throws IOException {
-                centroids = bigArrays.grow(centroids, bucket + 1);
+                latSum = bigArrays.grow(latSum, bucket + 1);
+                lonSum = bigArrays.grow(lonSum, bucket + 1);
+                lonCompensations = bigArrays.grow(lonCompensations, bucket + 1);
+                latCompensations = bigArrays.grow(latCompensations, bucket + 1);
                 counts = bigArrays.grow(counts, bucket + 1);
 
                 if (values.advanceExact(doc)) {
                     final int valueCount = values.docValueCount();
-                    double[] pt = new double[2];
-                    // get the previously accumulated number of counts
-                    long prevCounts = counts.get(bucket);
                     // increment by the number of points for this document
                     counts.increment(bucket, valueCount);
-                    // get the previous GeoPoint if a moving avg was
-                    // computed
-                    if (prevCounts > 0) {
-                        final long mortonCode = centroids.get(bucket);
-                        pt[0] = InternalGeoCentroid.decodeLongitude(mortonCode);
-                        pt[1] = InternalGeoCentroid.decodeLatitude(mortonCode);
-                    }
-                    // update the moving average
+                    // Compute the sum of double values with Kahan summation algorithm which is more
+                    // accurate than naive summation.
+                    double sumLat = latSum.get(bucket);
+                    double compensationLat = latCompensations.get(bucket);
+                    double sumLon = lonSum.get(bucket);
+                    double compensationLon = lonCompensations.get(bucket);
+
+                    // update the sum
                     for (int i = 0; i < valueCount; ++i) {
                         GeoPoint value = values.nextValue();
-                        pt[0] = pt[0] + (value.getLon() - pt[0]) / ++prevCounts;
-                        pt[1] = pt[1] + (value.getLat() - pt[1]) / prevCounts;
+                        //latitude
+                        double correctedLat = value.getLat() - compensationLat;
+                        double newSumLat = sumLat + correctedLat;
+                        compensationLat = (newSumLat - sumLat) - correctedLat;
+                        sumLat = newSumLat;
+                        //longitude
+                        double correctedLon = value.getLon() - compensationLon;
+                        double newSumLon = sumLon + correctedLon;
+                        compensationLon = (newSumLon - sumLon) - correctedLon;
+                        sumLon = newSumLon;
                     }
-                    // TODO: we do not need to interleave the lat and lon
-                    // bits here
-                    // should we just store them contiguously?
-                    centroids.set(bucket, InternalGeoCentroid.encodeLatLon(pt[1], pt[0]));
+                    lonSum.set(bucket, sumLon);
+                    lonCompensations.set(bucket, compensationLon);
+                    latSum.set(bucket, sumLat);
+                    latCompensations.set(bucket, compensationLat);
                 }
             }
         };
@@ -101,14 +113,12 @@ final class GeoCentroidAggregator extends MetricsAggregator {
 
     @Override
     public InternalAggregation buildAggregation(long bucket) {
-        if (valuesSource == null || bucket >= centroids.size()) {
+        if (valuesSource == null || bucket >= counts.size()) {
             return buildEmptyAggregation();
         }
         final long bucketCount = counts.get(bucket);
-        final long mortonCode = centroids.get(bucket);
         final GeoPoint bucketCentroid = (bucketCount > 0)
-                ? new GeoPoint(InternalGeoCentroid.decodeLatitude(mortonCode),
-                        InternalGeoCentroid.decodeLongitude(mortonCode))
+                ? new GeoPoint(latSum.get(bucket) / bucketCount, lonSum.get(bucket) / bucketCount)
                 : null;
         return new InternalGeoCentroid(name, bucketCentroid , bucketCount, pipelineAggregators(), metaData());
     }
@@ -120,6 +130,6 @@ final class GeoCentroidAggregator extends MetricsAggregator {
 
     @Override
     public void doClose() {
-        Releasables.close(centroids, counts);
+        Releasables.close(latSum, latCompensations, lonSum, lonCompensations, counts);
     }
 }

+ 14 - 4
server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalGeoCentroid.java

@@ -20,6 +20,7 @@
 package org.elasticsearch.search.aggregations.metrics;
 
 import org.apache.lucene.geo.GeoEncodingUtils;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.geo.GeoPoint;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -69,8 +70,13 @@ public class InternalGeoCentroid extends InternalAggregation implements GeoCentr
         super(in);
         count = in.readVLong();
         if (in.readBoolean()) {
-            final long hash = in.readLong();
-            centroid = new GeoPoint(decodeLatitude(hash), decodeLongitude(hash));
+            if (in.getVersion().onOrAfter(Version.V_7_1_0)) {
+                centroid = new GeoPoint(in.readDouble(), in.readDouble());
+            } else {
+                final long hash = in.readLong();
+                centroid = new GeoPoint(decodeLatitude(hash), decodeLongitude(hash));
+            }
+
         } else {
             centroid = null;
         }
@@ -81,8 +87,12 @@ public class InternalGeoCentroid extends InternalAggregation implements GeoCentr
         out.writeVLong(count);
         if (centroid != null) {
             out.writeBoolean(true);
-            // should we just write lat and lon separately?
-            out.writeLong(encodeLatLon(centroid.lat(), centroid.lon()));
+            if (out.getVersion().onOrAfter(Version.V_7_1_0)) {
+                out.writeDouble(centroid.lat());
+                out.writeDouble(centroid.lon());
+            } else {
+                out.writeLong(encodeLatLon(centroid.lat(), centroid.lon()));
+            }
         } else {
             out.writeBoolean(false);
         }

+ 1 - 3
server/src/test/java/org/elasticsearch/search/aggregations/metrics/GeoCentroidAggregatorTests.java

@@ -29,8 +29,6 @@ import org.elasticsearch.common.geo.GeoPoint;
 import org.elasticsearch.index.mapper.GeoPointFieldMapper;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.search.aggregations.AggregatorTestCase;
-import org.elasticsearch.search.aggregations.metrics.GeoCentroidAggregationBuilder;
-import org.elasticsearch.search.aggregations.metrics.InternalGeoCentroid;
 import org.elasticsearch.search.aggregations.support.AggregationInspectionHelper;
 import org.elasticsearch.test.geo.RandomGeoGenerator;
 
@@ -38,7 +36,7 @@ import java.io.IOException;
 
 public class GeoCentroidAggregatorTests extends AggregatorTestCase {
 
-    private static final double GEOHASH_TOLERANCE = 1E-4D;
+    private static final double GEOHASH_TOLERANCE = 1E-6D;
 
     public void testEmpty() throws Exception {
         try (Directory dir = newDirectory();