Browse Source

Added Sum aggregation tests and docs (#110984)

- Added SUM() aggregation tests (which also autogenerate the docs)
- Converted non-finite doubles to nulls in aggregator

The complete set of tests depends on
https://github.com/elastic/elasticsearch/issues/110437, as noted in the
code comments. Once that issue is resolved, the commented-out test cases
can be re-enabled and should pass.
Iván Cea Fontenla 1 year ago
parent
commit
101775b93d

+ 2 - 2
docs/reference/esql/functions/aggregation-functions.asciidoc

@@ -17,7 +17,7 @@ The <<esql-stats-by>> command supports these aggregate functions:
 * <<esql-min>>
 * <<esql-percentile>>
 * experimental:[] <<esql-agg-st-centroid>>
-* <<esql-agg-sum>>
+* <<esql-sum>>
 * <<esql-top>>
 * <<esql-agg-values>>
 * experimental:[] <<esql-agg-weighted-avg>>
@@ -28,11 +28,11 @@ include::count-distinct.asciidoc[]
 include::median.asciidoc[]
 include::median-absolute-deviation.asciidoc[]
 include::st_centroid_agg.asciidoc[]
-include::sum.asciidoc[]
 include::layout/avg.asciidoc[]
 include::layout/max.asciidoc[]
 include::layout/min.asciidoc[]
 include::layout/percentile.asciidoc[]
+include::layout/sum.asciidoc[]
 include::layout/top.asciidoc[]
 include::values.asciidoc[]
 include::weighted-avg.asciidoc[]

+ 5 - 0
docs/reference/esql/functions/description/sum.asciidoc

@@ -0,0 +1,5 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+*Description*
+
+The sum of a numeric expression.

+ 4 - 23
docs/reference/esql/functions/sum.asciidoc → docs/reference/esql/functions/examples/sum.asciidoc

@@ -1,22 +1,6 @@
-[discrete]
-[[esql-agg-sum]]
-=== `SUM`
+// This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
 
-*Syntax*
-
-[source,esql]
-----
-SUM(expression)
-----
-
-`expression`::
-Numeric expression.
-
-*Description*
-
-Returns the sum of a numeric expression.
-
-*Example*
+*Examples*
 
 [source.merge.styled,esql]
 ----
@@ -26,11 +10,7 @@ include::{esql-specs}/stats.csv-spec[tag=sum]
 |===
 include::{esql-specs}/stats.csv-spec[tag=sum-result]
 |===
-
-The expression can use inline functions. For example, to calculate
-the sum of each employee's maximum salary changes, apply the
-`MV_MAX` function to each row and then sum the results:
-
+The expression can use inline functions. For example, to calculate the sum of each employee's maximum salary changes, apply the `MV_MAX` function to each row and then sum the results:
 [source.merge.styled,esql]
 ----
 include::{esql-specs}/stats.csv-spec[tag=docsStatsSumNestedExpression]
@@ -39,3 +19,4 @@ include::{esql-specs}/stats.csv-spec[tag=docsStatsSumNestedExpression]
 |===
 include::{esql-specs}/stats.csv-spec[tag=docsStatsSumNestedExpression-result]
 |===
+

+ 48 - 0
docs/reference/esql/functions/kibana/definition/sum.json

@@ -0,0 +1,48 @@
+{
+  "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.",
+  "type" : "agg",
+  "name" : "sum",
+  "description" : "The sum of a numeric expression.",
+  "signatures" : [
+    {
+      "params" : [
+        {
+          "name" : "number",
+          "type" : "double",
+          "optional" : false,
+          "description" : ""
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "double"
+    },
+    {
+      "params" : [
+        {
+          "name" : "number",
+          "type" : "integer",
+          "optional" : false,
+          "description" : ""
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "long"
+    },
+    {
+      "params" : [
+        {
+          "name" : "number",
+          "type" : "long",
+          "optional" : false,
+          "description" : ""
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "long"
+    }
+  ],
+  "examples" : [
+    "FROM employees\n| STATS SUM(languages)",
+    "FROM employees\n| STATS total_salary_changes = SUM(MV_MAX(salary_change))"
+  ]
+}

+ 11 - 0
docs/reference/esql/functions/kibana/docs/sum.md

@@ -0,0 +1,11 @@
+<!--
+This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+-->
+
+### SUM
+The sum of a numeric expression.
+
+```
+FROM employees
+| STATS SUM(languages)
+```

+ 15 - 0
docs/reference/esql/functions/layout/sum.asciidoc

@@ -0,0 +1,15 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+[discrete]
+[[esql-sum]]
+=== `SUM`
+
+*Syntax*
+
+[.text-center]
+image::esql/functions/signature/sum.svg[Embedded,opts=inline]
+
+include::../parameters/sum.asciidoc[]
+include::../description/sum.asciidoc[]
+include::../types/sum.asciidoc[]
+include::../examples/sum.asciidoc[]

+ 6 - 0
docs/reference/esql/functions/parameters/sum.asciidoc

@@ -0,0 +1,6 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+*Parameters*
+
+`number`::
+

+ 1 - 0
docs/reference/esql/functions/signature/sum.svg

@@ -0,0 +1 @@
+<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="252" height="46" viewbox="0 0 252 46"><defs><style type="text/css">#guide .c{fill:none;stroke:#222222;}#guide .k{fill:#000000;font-family:Roboto Mono,Sans-serif;font-size:20px;}#guide .s{fill:#e4f4ff;stroke:#222222;}#guide .syn{fill:#8D8D8D;font-family:Roboto Mono,Sans-serif;font-size:20px;}</style></defs><path class="c" d="M0 31h5m56 0h10m32 0h10m92 0h10m32 0h5"/><rect class="s" x="5" y="5" width="56" height="36"/><text class="k" x="15" y="31">SUM</text><rect class="s" x="71" y="5" width="32" height="36" rx="7"/><text class="syn" x="81" y="31">(</text><rect class="s" x="113" y="5" width="92" height="36" rx="7"/><text class="k" x="123" y="31">number</text><rect class="s" x="215" y="5" width="32" height="36" rx="7"/><text class="syn" x="225" y="31">)</text></svg>

+ 11 - 0
docs/reference/esql/functions/types/sum.asciidoc

@@ -0,0 +1,11 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+*Supported types*
+
+[%header.monospaced.styled,format=dsv,separator=|]
+|===
+number | result
+double | double
+integer | long
+long | long
+|===

+ 3 - 3
x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec

@@ -81,7 +81,7 @@ double pi()
 "double st_y(point:geo_point|cartesian_point)"
 "boolean starts_with(str:keyword|text, prefix:keyword|text)"
 "keyword substring(string:keyword|text, start:integer, ?length:integer)"
-"long sum(number:double|integer|long)"
+"long|double sum(number:double|integer|long)"
 "double tan(angle:double|integer|long|unsigned_long)"
 "double tanh(angle:double|integer|long|unsigned_long)"
 double tau()
@@ -324,7 +324,7 @@ st_x          |Extracts the `x` coordinate from the supplied point. If the point
 st_y          |Extracts the `y` coordinate from the supplied point. If the points is of type `geo_point` this is equivalent to extracting the `latitude` value.
 starts_with   |Returns a boolean that indicates whether a keyword string starts with another string.
 substring     |Returns a substring of a string, specified by a start position and an optional length.
-sum           |The sum of a numeric field.
+sum           |The sum of a numeric expression.
 tan           |Returns the {wikipedia}/Sine_and_cosine[Tangent] trigonometric function of an angle.
 tanh          |Returns the {wikipedia}/Hyperbolic_functions[Tangent] hyperbolic function of an angle.
 tau           |Returns the https://tauday.com/tau-manifesto[ratio] of a circle's circumference to its radius.
@@ -447,7 +447,7 @@ st_x          |double
 st_y          |double                                                                                                                      |false                       |false           |false
 starts_with   |boolean                                                                                                                     |[false, false]              |false           |false
 substring     |keyword                                                                                                                     |[false, false, true]        |false           |false
-sum           |long                                                                                                                        |false                       |false           |true
+sum           |"long|double"                                                                                                               |false                       |false           |true
 tan           |double                                                                                                                      |false                       |false           |false
 tanh          |double                                                                                                                      |false                       |false           |false
 tau           |double                                                                                                                      |null                        |false           |false

+ 15 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java

@@ -19,6 +19,7 @@ import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.type.DataType;
 import org.elasticsearch.xpack.esql.core.util.StringUtils;
 import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
+import org.elasticsearch.xpack.esql.expression.function.Example;
 import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
 import org.elasticsearch.xpack.esql.expression.function.Param;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum;
@@ -37,7 +38,20 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG;
 public class Sum extends NumericAggregate implements SurrogateExpression {
     public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sum", Sum::new);
 
-    @FunctionInfo(returnType = "long", description = "The sum of a numeric field.", isAggregation = true)
+    @FunctionInfo(
+        returnType = { "long", "double" },
+        description = "The sum of a numeric expression.",
+        isAggregation = true,
+        examples = {
+            @Example(file = "stats", tag = "sum"),
+            @Example(
+                description = "The expression can use inline functions. For example, to calculate "
+                    + "the sum of each employee's maximum salary changes, apply the "
+                    + "`MV_MAX` function to each row and then sum the results",
+                file = "stats",
+                tag = "docsStatsSumNestedExpression"
+            ) }
+    )
     public Sum(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
         super(source, field);
     }

+ 132 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java

@@ -0,0 +1,132 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.aggregate;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase;
+import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG;
+import static org.hamcrest.Matchers.equalTo;
+
+public class SumTests extends AbstractAggregationTestCase {
+    public SumTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
+        this.testCase = testCaseSupplier.get();
+    }
+
+    @ParametersFactory
+    public static Iterable<Object[]> parameters() {
+        var suppliers = new ArrayList<TestCaseSupplier>();
+
+        Stream.of(MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true)
+        // Longs currently throw on overflow.
+        // Restore after https://github.com/elastic/elasticsearch/issues/110437
+        // MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true),
+        // Doubles currently return +/-Infinity on overflow.
+        // Restore after https://github.com/elastic/elasticsearch/issues/111026
+        // MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true)
+        ).flatMap(List::stream).map(SumTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers));
+
+        suppliers.addAll(
+            List.of(
+                // Folding
+                new TestCaseSupplier(
+                    List.of(DataType.INTEGER),
+                    () -> new TestCaseSupplier.TestCase(
+                        List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field")),
+                        "Sum[field=Attribute[channel=0]]",
+                        DataType.LONG,
+                        equalTo(200L)
+                    )
+                ),
+                new TestCaseSupplier(
+                    List.of(DataType.LONG),
+                    () -> new TestCaseSupplier.TestCase(
+                        List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "field")),
+                        "Sum[field=Attribute[channel=0]]",
+                        DataType.LONG,
+                        equalTo(200L)
+                    )
+                ),
+                new TestCaseSupplier(
+                    List.of(DataType.DOUBLE),
+                    () -> new TestCaseSupplier.TestCase(
+                        List.of(TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "field")),
+                        "Sum[field=Attribute[channel=0]]",
+                        DataType.DOUBLE,
+                        equalTo(200.)
+                    )
+                )
+            )
+        );
+
+        return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers);
+    }
+
+    @Override
+    protected Expression build(Source source, List<Expression> args) {
+        return new Sum(source, args.get(0));
+    }
+
+    private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) {
+        return new TestCaseSupplier(List.of(fieldSupplier.type()), () -> {
+            var fieldTypedData = fieldSupplier.get();
+
+            Object expected;
+
+            try {
+                expected = switch (fieldTypedData.type().widenSmallNumeric()) {
+                    case INTEGER -> fieldTypedData.multiRowData()
+                        .stream()
+                        .map(v -> (Integer) v)
+                        .collect(Collectors.summarizingInt(Integer::intValue))
+                        .getSum();
+                    case LONG -> fieldTypedData.multiRowData()
+                        .stream()
+                        .map(v -> (Long) v)
+                        .collect(Collectors.summarizingLong(Long::longValue))
+                        .getSum();
+                    case DOUBLE -> {
+                        var value = fieldTypedData.multiRowData()
+                            .stream()
+                            .map(v -> (Double) v)
+                            .collect(Collectors.summarizingDouble(Double::doubleValue))
+                            .getSum();
+
+                        if (Double.isInfinite(value) || Double.isNaN(value)) {
+                            yield null;
+                        }
+
+                        yield value;
+                    }
+                    default -> throw new IllegalStateException("Unexpected value: " + fieldTypedData.type());
+                };
+            } catch (Exception e) {
+                expected = null;
+            }
+
+            var dataType = fieldTypedData.type().isWholeNumber() == false || fieldTypedData.type() == UNSIGNED_LONG
+                ? DataType.DOUBLE
+                : DataType.LONG;
+
+            return new TestCaseSupplier.TestCase(List.of(fieldTypedData), "Sum[field=Attribute[channel=0]]", dataType, equalTo(expected));
+        });
+    }
+}