Browse Source

[Transform] Support `range` aggregation in transform (#86501)

Przemysław Witek 3 years ago
parent
commit
70e37ae7c6

+ 1 - 0
docs/reference/rest-api/common-parms.asciidoc

@@ -703,6 +703,7 @@ currently supported:
 * <<search-aggregations-metrics-min-aggregation,Min>>
 * <<search-aggregations-bucket-missing-aggregation,Missing>>
 * <<search-aggregations-metrics-percentile-aggregation,Percentiles>>
+* <<search-aggregations-bucket-range-aggregation,Range>>
 * <<search-aggregations-bucket-rare-terms-aggregation, Rare Terms>>
 * <<search-aggregations-metrics-scripted-metric-aggregation,Scripted metric>>
 * <<search-aggregations-metrics-stats-aggregation,Stats>>

+ 98 - 0
x-pack/plugin/transform/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/transform/integration/TransformPivotRestIT.java

@@ -1994,6 +1994,104 @@ public class TransformPivotRestIT extends TransformRestTestCase {
         assertEquals(5, actual.longValue());
     }
 
+    public void testPivotWithRanges() throws Exception {
+        String transformId = "range_pivot";
+        String transformIndex = "range_pivot_reviews";
+        boolean keyed = randomBoolean();
+        setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME, transformIndex);
+        final Request createTransformRequest = createRequestWithAuth(
+            "PUT",
+            getTransformEndpoint() + transformId,
+            BASIC_AUTH_VALUE_TRANSFORM_ADMIN_WITH_SOME_DATA_ACCESS
+        );
+        String config = """
+            {
+              "source": {
+                "index": "%s"
+              },
+              "dest": {
+                "index": "%s"
+              },
+              "frequency": "1s",
+              "pivot": {
+                "group_by": {
+                  "reviewer": {
+                    "terms": {
+                      "field": "user_id"
+                    }
+                  }
+                },
+                "aggregations": {
+                  "avg_rating": {
+                    "avg": {
+                      "field": "stars"
+                    }
+                  },
+                  "ranges": {
+                    "range": {
+                      "field": "stars",
+                      "keyed": %s,
+                      "ranges": [ { "to": 2 }, { "from": 2, "to": 3.99 }, { "from": 4 } ]
+                    }
+                  },
+                  "ranges-avg": {
+                    "range": {
+                      "field": "stars",
+                      "keyed": %s,
+                      "ranges": [ { "to": 2 }, { "from": 2, "to": 3.99 }, { "from": 4 } ]
+                    },
+                    "aggs": { "avg_stars": { "avg": { "field": "stars" } } }
+                  }
+                }
+              }
+            }""".formatted(REVIEWS_INDEX_NAME, transformIndex, keyed, keyed);
+        createTransformRequest.setJsonEntity(config);
+        Map<String, Object> createTransformResponse = entityAsMap(client().performRequest(createTransformRequest));
+        assertThat(createTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
+
+        startAndWaitForTransform(transformId, transformIndex);
+        assertTrue(indexExists(transformIndex));
+
+        // check destination index mappings
+        Map<String, Object> mappingsResult = getAsMap(transformIndex + "/_mapping");
+        assertThat(
+            XContentMapValues.extractValue("range_pivot_reviews.mappings.properties.ranges.properties", mappingsResult),
+            is(equalTo(Map.of("4-*", Map.of("type", "long"), "2-3_99", Map.of("type", "long"), "*-2", Map.of("type", "long"))))
+        );
+        assertThat(
+            XContentMapValues.extractValue("range_pivot_reviews.mappings.properties.ranges-avg.properties", mappingsResult),
+            is(
+                equalTo(
+                    Map.of(
+                        "4-*",
+                        Map.of("properties", Map.of("avg_stars", Map.of("type", "double"))),
+                        "2-3_99",
+                        Map.of("properties", Map.of("avg_stars", Map.of("type", "double"))),
+                        "*-2",
+                        Map.of("properties", Map.of("avg_stars", Map.of("type", "double")))
+                    )
+                )
+            )
+        );
+
+        // get and check some users
+        Map<String, Object> searchResult = getAsMap(transformIndex + "/_search?q=reviewer:user_11");
+        assertEquals(1, XContentMapValues.extractValue("hits.total.value", searchResult));
+        Number actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.ranges.*-2", searchResult)).get(0);
+        assertEquals(5, actual.longValue());
+        actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.ranges.2-3_99", searchResult)).get(0);
+        assertEquals(2, actual.longValue());
+        actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.ranges.4-*", searchResult)).get(0);
+        assertEquals(19, actual.longValue());
+
+        actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.ranges-avg.*-2.avg_stars", searchResult)).get(0);
+        assertEquals(1.0, actual.doubleValue(), 1E-6);
+        actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.ranges-avg.2-3_99.avg_stars", searchResult)).get(0);
+        assertEquals(3.0, actual.doubleValue(), 1E-6);
+        actual = (Number) ((List<?>) XContentMapValues.extractValue("hits.hits._source.ranges-avg.4-*.avg_stars", searchResult)).get(0);
+        assertEquals(4.6842105, actual.doubleValue(), 1E-6);
+    }
+
     public void testPivotWithFilter() throws Exception {
         String transformId = "filter_pivot";
         String transformIndex = "filter_pivot_reviews";

+ 35 - 2
x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/pivot/AggregationResultUtils.java

@@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
 import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregation;
 import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
 import org.elasticsearch.search.aggregations.bucket.geogrid.GeoTileUtils;
+import org.elasticsearch.search.aggregations.bucket.range.Range;
 import org.elasticsearch.search.aggregations.metrics.GeoBounds;
 import org.elasticsearch.search.aggregations.metrics.GeoCentroid;
 import org.elasticsearch.search.aggregations.metrics.MultiValueAggregation;
@@ -43,6 +44,8 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -59,6 +62,7 @@ public final class AggregationResultUtils {
         tempMap.put(GeoCentroid.class.getName(), new GeoCentroidAggExtractor());
         tempMap.put(GeoBounds.class.getName(), new GeoBoundsAggExtractor());
         tempMap.put(Percentiles.class.getName(), new PercentilesAggExtractor());
+        tempMap.put(Range.class.getName(), new RangeAggExtractor());
         tempMap.put(SingleBucketAggregation.class.getName(), new SingleBucketAggExtractor());
         tempMap.put(MultiBucketsAggregation.class.getName(), new MultiBucketsAggExtractor());
         tempMap.put(GeoShapeMetricAggregation.class.getName(), new GeoShapeMetricAggExtractor());
@@ -166,6 +170,10 @@ public final class AggregationResultUtils {
             // TODO: can the Percentiles extractor be removed?
         } else if (aggregation instanceof Percentiles) {
             return TYPE_VALUE_EXTRACTOR_MAP.get(Percentiles.class.getName());
+            // note: range is also a multi bucket agg, therefore check range first
+            // TODO: can the Range extractor be removed?
+        } else if (aggregation instanceof Range) {
+            return TYPE_VALUE_EXTRACTOR_MAP.get(Range.class.getName());
         } else if (aggregation instanceof MultiValue) {
             return TYPE_VALUE_EXTRACTOR_MAP.get(MultiValue.class.getName());
         } else if (aggregation instanceof MultiValueAggregation) {
@@ -334,6 +342,19 @@ public final class AggregationResultUtils {
         }
     }
 
+    static class RangeAggExtractor extends MultiBucketsAggExtractor {
+
+        RangeAggExtractor() {
+            super(RangeAggExtractor::transformBucketKey);
+        }
+
+        private static String transformBucketKey(String bucketKey) {
+            return bucketKey.replace(".0-", "-")  // from: convert double to integer
+                .replaceAll("\\.0$", "")  // to: convert double to integer
+                .replace('.', '_');  // convert remaining dots with underscores so that the key prefix is not treated as object
+        }
+    }
+
     static class SingleBucketAggExtractor implements AggValueExtractor {
         @Override
         public Object value(Aggregation agg, Map<String, String> fieldTypeMap, String lookupFieldPrefix) {
@@ -360,6 +381,17 @@ public final class AggregationResultUtils {
     }
 
     static class MultiBucketsAggExtractor implements AggValueExtractor {
+
+        private final Function<String, String> bucketKeyTransfomer;
+
+        MultiBucketsAggExtractor() {
+            this(Function.identity());
+        }
+
+        MultiBucketsAggExtractor(Function<String, String> bucketKeyTransfomer) {
+            this.bucketKeyTransfomer = Objects.requireNonNull(bucketKeyTransfomer);
+        }
+
         @Override
         public Object value(Aggregation agg, Map<String, String> fieldTypeMap, String lookupFieldPrefix) {
             MultiBucketsAggregation aggregation = (MultiBucketsAggregation) agg;
@@ -367,8 +399,9 @@ public final class AggregationResultUtils {
             HashMap<String, Object> nested = new HashMap<>();
 
             for (MultiBucketsAggregation.Bucket bucket : aggregation.getBuckets()) {
+                String bucketKey = bucketKeyTransfomer.apply(bucket.getKeyAsString());
                 if (bucket.getAggregations().iterator().hasNext() == false) {
-                    nested.put(bucket.getKeyAsString(), bucket.getDocCount());
+                    nested.put(bucketKey, bucket.getDocCount());
                 } else {
                     HashMap<String, Object> nestedBucketObject = new HashMap<>();
                     for (Aggregation subAgg : bucket.getAggregations()) {
@@ -381,7 +414,7 @@ public final class AggregationResultUtils {
                             )
                         );
                     }
-                    nested.put(bucket.getKeyAsString(), nestedBucketObject);
+                    nested.put(bucketKey, nestedBucketObject);
                 }
             }
             return nested;

+ 33 - 5
x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/pivot/TransformAggregations.java

@@ -9,6 +9,8 @@ package org.elasticsearch.xpack.transform.transforms.pivot;
 
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.bucket.range.RangeAggregationBuilder;
+import org.elasticsearch.search.aggregations.bucket.range.RangeAggregator.Range;
 import org.elasticsearch.search.aggregations.metrics.PercentilesAggregationBuilder;
 import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
 import org.elasticsearch.xpack.transform.utils.OutputFieldNameConverter;
@@ -70,7 +72,6 @@ public final class TransformAggregations {
         "matrix_stats",
         "nested",
         "percentile_ranks",
-        "range",
         "random_sampler",
         "reverse_nested",
         "sampler",
@@ -112,6 +113,7 @@ public final class TransformAggregations {
         BUCKET_SELECTOR("bucket_selector", DYNAMIC),
         BUCKET_SCRIPT("bucket_script", DYNAMIC),
         PERCENTILES("percentiles", DOUBLE),
+        RANGE("range", LONG),
         FILTER("filter", LONG),
         TERMS("terms", FLATTENED),
         RARE_TERMS("rare_terms", FLATTENED),
@@ -207,18 +209,37 @@ public final class TransformAggregations {
         // todo: can this be removed?
         if (agg instanceof PercentilesAggregationBuilder percentilesAgg) {
 
-            // note: eclipse does not like p -> agg.getType()
             // the merge function (p1, p2) -> p1 ignores duplicates
             return new Tuple<>(
                 Collections.emptyMap(),
                 Arrays.stream(percentilesAgg.percentiles())
                     .mapToObj(OutputFieldNameConverter::fromDouble)
-                    .collect(
-                        Collectors.toMap(p -> percentilesAgg.getName() + "." + p, p -> { return percentilesAgg.getType(); }, (p1, p2) -> p1)
-                    )
+                    .collect(Collectors.toMap(p -> percentilesAgg.getName() + "." + p, p -> percentilesAgg.getType(), (p1, p2) -> p1))
             );
         }
 
+        if (agg instanceof RangeAggregationBuilder rangeAgg) {
+            HashMap<String, String> outputTypes = new HashMap<>();
+            HashMap<String, String> inputTypes = new HashMap<>();
+            for (Range range : rangeAgg.ranges()) {
+                String fieldName = rangeAgg.getName() + "." + generateKeyForRange(range.getFrom(), range.getTo());
+                if (rangeAgg.getSubAggregations().isEmpty()) {
+                    outputTypes.put(fieldName, AggregationType.RANGE.getName());
+                    continue;
+                }
+                for (AggregationBuilder subAgg : rangeAgg.getSubAggregations()) {
+                    Tuple<Map<String, String>, Map<String, String>> subAggregationTypes = getAggregationInputAndOutputTypes(subAgg);
+                    for (Entry<String, String> subAggOutputType : subAggregationTypes.v2().entrySet()) {
+                        outputTypes.put(String.join(".", fieldName, subAggOutputType.getKey()), subAggOutputType.getValue());
+                    }
+                    for (Entry<String, String> subAggInputType : subAggregationTypes.v1().entrySet()) {
+                        inputTypes.put(String.join(".", fieldName, subAggInputType.getKey()), subAggInputType.getValue());
+                    }
+                }
+            }
+            return new Tuple<>(inputTypes, outputTypes);
+        }
+
         // does the agg specify output field names
         Optional<Set<String>> outputFieldNames = agg.getOutputFieldNames();
         if (outputFieldNames.isPresent()) {
@@ -267,4 +288,11 @@ public final class TransformAggregations {
         return new Tuple<>(Collections.emptyMap(), Collections.singletonMap(agg.getName(), agg.getType()));
     }
 
+    // Visible for testing
+    static String generateKeyForRange(double from, double to) {
+        return new StringBuilder().append(Double.isInfinite(from) ? "*" : OutputFieldNameConverter.fromDouble(from))
+            .append("-")
+            .append(Double.isInfinite(to) ? "*" : OutputFieldNameConverter.fromDouble(to))
+            .toString();
+    }
 }

+ 61 - 11
x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/AggregationResultUtilsTests.java

@@ -10,16 +10,20 @@ package org.elasticsearch.xpack.transform.transforms.pivot;
 import org.elasticsearch.common.geo.GeoPoint;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.script.Script;
+import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.aggregations.Aggregation;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.AggregationReduceContext;
 import org.elasticsearch.search.aggregations.InternalAggregation;
+import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
 import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregation;
 import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
 import org.elasticsearch.search.aggregations.bucket.composite.ParsedComposite;
+import org.elasticsearch.search.aggregations.bucket.range.InternalRange;
+import org.elasticsearch.search.aggregations.bucket.range.Range;
 import org.elasticsearch.search.aggregations.bucket.terms.DoubleTerms;
 import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
 import org.elasticsearch.search.aggregations.bucket.terms.ParsedDoubleTerms;
@@ -72,7 +76,6 @@ import org.elasticsearch.xpack.core.transform.transforms.pivot.GroupConfig;
 import org.elasticsearch.xpack.transform.transforms.pivot.AggregationResultUtils.BucketKeyExtractor;
 
 import java.io.IOException;
-import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
@@ -873,7 +876,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
             expectedObject.put("type", type);
             double lat = randomDoubleBetween(-90.0, 90.0, false);
             double lon = randomDoubleBetween(-180.0, 180.0, false);
-            expectedObject.put("coordinates", Arrays.asList(lon, lat));
+            expectedObject.put("coordinates", asList(lon, lat));
             agg = createGeoBounds(new GeoPoint(lat, lon), new GeoPoint(lat, lon));
             assertThat(AggregationResultUtils.getExtractor(agg).value(agg, Collections.emptyMap(), ""), equalTo(expectedObject));
         }
@@ -920,12 +923,12 @@ public class AggregationResultUtilsTests extends ESTestCase {
             List<List<Double[]>> coordinates = (List<List<Double[]>>) geoJson.get("coordinates");
             assertThat(coordinates.size(), equalTo(1));
             assertThat(coordinates.get(0).size(), equalTo(5));
-            List<List<Double>> expected = Arrays.asList(
-                Arrays.asList(lon, lat),
-                Arrays.asList(lon2, lat),
-                Arrays.asList(lon2, lat2),
-                Arrays.asList(lon, lat2),
-                Arrays.asList(lon, lat)
+            List<List<Double>> expected = asList(
+                asList(lon, lat),
+                asList(lon2, lat),
+                asList(lon2, lat2),
+                asList(lon, lat2),
+                asList(lon, lat)
             );
             for (int j = 0; j < 5; j++) {
                 Double[] coordinate = coordinates.get(0).get(j);
@@ -947,7 +950,7 @@ public class AggregationResultUtilsTests extends ESTestCase {
     public void testPercentilesAggExtractor() {
         Aggregation agg = createPercentilesAgg(
             "p_agg",
-            Arrays.asList(new Percentile(1, 0), new Percentile(50, 22.2), new Percentile(99, 43.3), new Percentile(99.5, 100.3))
+            asList(new Percentile(1, 0), new Percentile(50, 22.2), new Percentile(99, 43.3), new Percentile(99.5, 100.3))
         );
         assertThat(
             AggregationResultUtils.getExtractor(agg).value(agg, Collections.emptyMap(), ""),
@@ -956,17 +959,64 @@ public class AggregationResultUtilsTests extends ESTestCase {
     }
 
     public void testPercentilesAggExtractorNaN() {
-        Aggregation agg = createPercentilesAgg("p_agg", Arrays.asList(new Percentile(1, Double.NaN), new Percentile(50, Double.NaN)));
+        Aggregation agg = createPercentilesAgg("p_agg", asList(new Percentile(1, Double.NaN), new Percentile(50, Double.NaN)));
         assertThat(AggregationResultUtils.getExtractor(agg).value(agg, Collections.emptyMap(), ""), equalTo(asMap("1", null, "50", null)));
     }
 
+    @SuppressWarnings("unchecked")
+    public static Range createRangeAgg(String name, List<InternalRange.Bucket> buckets) {
+        Range agg = mock(Range.class);
+        when(agg.getName()).thenReturn(name);
+        when(agg.getBuckets()).thenReturn((List) buckets);
+        return agg;
+    }
+
+    public void testRangeAggExtractor() {
+        Aggregation agg = createRangeAgg(
+            "p_agg",
+            asList(
+                new InternalRange.Bucket(null, Double.NEGATIVE_INFINITY, 10.5, 10, InternalAggregations.EMPTY, false, DocValueFormat.RAW),
+                new InternalRange.Bucket(null, 10.5, 19.5, 30, InternalAggregations.EMPTY, false, DocValueFormat.RAW),
+                new InternalRange.Bucket(null, 19.5, 200, 30, InternalAggregations.EMPTY, false, DocValueFormat.RAW),
+                new InternalRange.Bucket(null, 20, Double.POSITIVE_INFINITY, 0, InternalAggregations.EMPTY, false, DocValueFormat.RAW),
+                new InternalRange.Bucket(null, -10, -5, 0, InternalAggregations.EMPTY, false, DocValueFormat.RAW),
+                new InternalRange.Bucket(null, -11.0, -6.0, 0, InternalAggregations.EMPTY, false, DocValueFormat.RAW),
+                new InternalRange.Bucket(null, -11.0, 0, 0, InternalAggregations.EMPTY, false, DocValueFormat.RAW),
+                new InternalRange.Bucket("custom-0", 0, 10, 777, InternalAggregations.EMPTY, false, DocValueFormat.RAW)
+            )
+        );
+        assertThat(
+            AggregationResultUtils.getExtractor(agg).value(agg, Collections.emptyMap(), ""),
+            equalTo(
+                asMap(
+                    "*-10_5",
+                    10L,
+                    "10_5-19_5",
+                    30L,
+                    "19_5-200",
+                    30L,
+                    "20-*",
+                    0L,
+                    "-10--5",
+                    0L,
+                    "-11--6",
+                    0L,
+                    "-11-0",
+                    0L,
+                    "custom-0",
+                    777L
+                )
+            )
+        );
+    }
+
     public static SingleBucketAggregation createSingleBucketAgg(String name, long docCount, Aggregation... subAggregations) {
         SingleBucketAggregation agg = mock(SingleBucketAggregation.class);
         when(agg.getDocCount()).thenReturn(docCount);
         when(agg.getName()).thenReturn(name);
         if (subAggregations != null) {
             org.elasticsearch.search.aggregations.Aggregations subAggs = new org.elasticsearch.search.aggregations.Aggregations(
-                Arrays.asList(subAggregations)
+                asList(subAggregations)
             );
             when(agg.getAggregations()).thenReturn(subAggs);
         } else {

+ 16 - 1
x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/AggregationSchemaAndResultTests.java

@@ -119,6 +119,16 @@ public class AggregationSchemaAndResultTests extends ESTestCase {
         // percentile produces 1 output per percentile + 1 for the parent object
         aggs.addAggregator(AggregationBuilders.percentiles("p_rating").field("long_stars").percentiles(1, 5, 10, 50, 99.9));
 
+        // range produces 1 output per range + 1 for the parent object
+        aggs.addAggregator(
+            AggregationBuilders.range("some_range")
+                .field("long_stars")
+                .addUnboundedTo(10.5)
+                .addRange(10.5, 19.5)
+                .addRange(19.5, 20)
+                .addUnboundedFrom(20)
+        );
+
         // scripted metric produces no output because its dynamic
         aggs.addAggregator(AggregationBuilders.scriptedMetric("collapsed_ratings"));
 
@@ -134,7 +144,7 @@ public class AggregationSchemaAndResultTests extends ESTestCase {
         this.<Map<String, String>>assertAsync(
             listener -> SchemaUtil.deduceMappings(client, pivotConfig, new String[] { "source-index" }, emptyMap(), listener),
             mappings -> {
-                assertEquals(numGroupsWithoutScripts + 10, mappings.size());
+                assertEquals("Mappings were: " + mappings, numGroupsWithoutScripts + 15, mappings.size());
                 assertEquals("long", mappings.get("max_rating"));
                 assertEquals("double", mappings.get("avg_rating"));
                 assertEquals("long", mappings.get("count_rating"));
@@ -144,6 +154,11 @@ public class AggregationSchemaAndResultTests extends ESTestCase {
                 assertEquals("double", mappings.get("p_rating.5"));
                 assertEquals("double", mappings.get("p_rating.10"));
                 assertEquals("double", mappings.get("p_rating.99_9"));
+                assertEquals("object", mappings.get("some_range"));
+                assertEquals("long", mappings.get("some_range.*-10_5"));
+                assertEquals("long", mappings.get("some_range.10_5-19_5"));
+                assertEquals("long", mappings.get("some_range.19_5-20"));
+                assertEquals("long", mappings.get("some_range.20-*"));
 
                 Aggregation agg = AggregationResultUtilsTests.createSingleMetricAgg("avg_rating", 33.3, "33.3");
                 assertThat(AggregationResultUtils.getExtractor(agg).value(agg, mappings, ""), equalTo(33.3));

+ 111 - 0
x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/TransformAggregationsTests.java

@@ -13,7 +13,9 @@ import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.query.TermQueryBuilder;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
+import org.elasticsearch.search.aggregations.bucket.range.RangeAggregationBuilder;
 import org.elasticsearch.search.aggregations.matrix.MatrixAggregationPlugin;
 import org.elasticsearch.search.aggregations.metrics.MaxAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
@@ -27,6 +29,9 @@ import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+
 public class TransformAggregationsTests extends ESTestCase {
     public void testResolveTargetMapping() {
 
@@ -63,6 +68,12 @@ public class TransformAggregationsTests extends ESTestCase {
         assertEquals("double", TransformAggregations.resolveTargetMapping("sum", "half_float"));
         assertEquals("double", TransformAggregations.resolveTargetMapping("sum", null));
 
+        // range
+        assertEquals("long", TransformAggregations.resolveTargetMapping("range", "int"));
+        assertEquals("long", TransformAggregations.resolveTargetMapping("range", "double"));
+        assertEquals("long", TransformAggregations.resolveTargetMapping("range", "half_float"));
+        assertEquals("long", TransformAggregations.resolveTargetMapping("range", "scaled_float"));
+
         // geo_centroid
         assertEquals("geo_point", TransformAggregations.resolveTargetMapping("geo_centroid", "geo_point"));
         assertEquals("geo_point", TransformAggregations.resolveTargetMapping("geo_centroid", null));
@@ -173,6 +184,17 @@ public class TransformAggregationsTests extends ESTestCase {
         assertEquals("percentiles", outputTypes.get("percentiles.1"));
         assertEquals("percentiles", outputTypes.get("percentiles.5"));
         assertEquals("percentiles", outputTypes.get("percentiles.10"));
+
+        percentialAggregationBuilder = new PercentilesAggregationBuilder("percentiles", new double[] { 1.2, 5.5, 10.7 }, null);
+
+        inputAndOutputTypes = TransformAggregations.getAggregationInputAndOutputTypes(percentialAggregationBuilder);
+        assertTrue(inputAndOutputTypes.v1().isEmpty());
+        outputTypes = inputAndOutputTypes.v2();
+
+        assertEquals(3, outputTypes.size());
+        assertEquals("percentiles", outputTypes.get("percentiles.1_2"));
+        assertEquals("percentiles", outputTypes.get("percentiles.5_5"));
+        assertEquals("percentiles", outputTypes.get("percentiles.10_7"));
     }
 
     public void testGetAggregationOutputTypesStats() {
@@ -190,6 +212,84 @@ public class TransformAggregationsTests extends ESTestCase {
         assertEquals("stats", outputTypes.get("stats.sum"));
     }
 
+    public void testGetAggregationOutputTypesRange() {
+        {
+            AggregationBuilder rangeAggregationBuilder = new RangeAggregationBuilder("range_agg_name").addUnboundedTo(100)
+                .addRange(100, 200)
+                .addUnboundedFrom(200);
+            var inputAndOutputTypes = TransformAggregations.getAggregationInputAndOutputTypes(rangeAggregationBuilder);
+            assertThat(
+                inputAndOutputTypes,
+                is(
+                    equalTo(
+                        Tuple.tuple(
+                            Map.of(),
+                            Map.of("range_agg_name.*-100", "range", "range_agg_name.100-200", "range", "range_agg_name.200-*", "range")
+                        )
+                    )
+                )
+            );
+        }
+
+        {
+            AggregationBuilder rangeAggregationBuilder = new RangeAggregationBuilder("range_agg_name").addUnboundedTo(100.5)
+                .addRange(100.5, 200.7)
+                .addUnboundedFrom(200.7);
+            var inputAndOutputTypes = TransformAggregations.getAggregationInputAndOutputTypes(rangeAggregationBuilder);
+            assertThat(
+                inputAndOutputTypes,
+                is(
+                    equalTo(
+                        Tuple.tuple(
+                            Map.of(),
+                            Map.of(
+                                "range_agg_name.*-100_5",
+                                "range",
+                                "range_agg_name.100_5-200_7",
+                                "range",
+                                "range_agg_name.200_7-*",
+                                "range"
+                            )
+                        )
+                    )
+                )
+            );
+        }
+
+        {
+            AggregationBuilder rangeAggregationBuilder = new RangeAggregationBuilder("range_agg_name").addUnboundedTo(100.5)
+                .addRange(100.5, 200.7)
+                .addUnboundedFrom(200.7)
+                .subAggregation(AggregationBuilders.avg("my-avg").field("my-field"));
+            var inputAndOutputTypes = TransformAggregations.getAggregationInputAndOutputTypes(rangeAggregationBuilder);
+            assertThat(
+                inputAndOutputTypes,
+                is(
+                    equalTo(
+                        Tuple.tuple(
+                            Map.of(
+                                "range_agg_name.*-100_5.my-avg",
+                                "my-field",
+                                "range_agg_name.100_5-200_7.my-avg",
+                                "my-field",
+                                "range_agg_name.200_7-*.my-avg",
+                                "my-field"
+                            ),
+                            Map.of(
+                                "range_agg_name.*-100_5.my-avg",
+                                "avg",
+                                "range_agg_name.100_5-200_7.my-avg",
+                                "avg",
+                                "range_agg_name.200_7-*.my-avg",
+                                "avg"
+                            )
+                        )
+                    )
+                )
+            );
+        }
+    }
+
     public void testGetAggregationOutputTypesSubAggregations() {
 
         AggregationBuilder filterAggregationBuilder = new FilterAggregationBuilder("filter_1", new TermQueryBuilder("type", "cat"));
@@ -260,4 +360,15 @@ public class TransformAggregationsTests extends ESTestCase {
         assertEquals("percentiles", outputTypes.get("filter_1.filter_2.percentiles.88_8"));
         assertEquals("percentiles", outputTypes.get("filter_1.filter_2.percentiles.99_5"));
     }
+
+    public void testGenerateKeyForRange() {
+        assertThat(TransformAggregations.generateKeyForRange(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY), is(equalTo("*-*")));
+        assertThat(TransformAggregations.generateKeyForRange(Double.NEGATIVE_INFINITY, 0.0), is(equalTo("*-0")));
+        assertThat(TransformAggregations.generateKeyForRange(0.0, 0.0), is(equalTo("0-0")));
+        assertThat(TransformAggregations.generateKeyForRange(10.0, 10.0), is(equalTo("10-10")));
+        assertThat(TransformAggregations.generateKeyForRange(10.5, 10.5), is(equalTo("10_5-10_5")));
+        assertThat(TransformAggregations.generateKeyForRange(10.5, 19.5), is(equalTo("10_5-19_5")));
+        assertThat(TransformAggregations.generateKeyForRange(19.5, 20), is(equalTo("19_5-20")));
+        assertThat(TransformAggregations.generateKeyForRange(20, Double.POSITIVE_INFINITY), is(equalTo("20-*")));
+    }
 }

+ 1 - 0
x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/utils/OutputFieldNameConverterTests.java

@@ -25,5 +25,6 @@ public class OutputFieldNameConverterTests extends ESTestCase {
         assertEquals("NaN", OutputFieldNameConverter.fromDouble(Double.NaN));
         // infinity
         assertEquals("-Infinity", OutputFieldNameConverter.fromDouble(Double.NEGATIVE_INFINITY));
+        assertEquals("Infinity", OutputFieldNameConverter.fromDouble(Double.POSITIVE_INFINITY));
     }
 }