Просмотр исходного кода

ESQL: AVG aggregation tests and ignore complex surrogates (#110579)

Some work around aggregation tests, with AVG as an example:
- Added tests and autogenerated docs for AVG
- As AVG uses "complex" surrogates (a combination of functions), we can't trivially execute them without a complete plan. As I'm not sure it's worth it for most aggregations, I'm skipping those cases for now, so as to avoid blocking other aggs tests.

The bad side effect of skipping those tests is that most tests in AvgTests are actually ignored (74 of 100).
Iván Cea Fontenla 1 год назад
Родитель
Commit
38cd0b333e

+ 2 - 2
docs/reference/esql/functions/aggregation-functions.asciidoc

@@ -8,7 +8,7 @@
 The <<esql-stats-by>> command supports these aggregate functions:
 
 // tag::agg_list[]
-* <<esql-agg-avg>>
+* <<esql-avg>>
 * <<esql-agg-count>>
 * <<esql-agg-count-distinct>>
 * <<esql-agg-max>>
@@ -23,7 +23,6 @@ The <<esql-stats-by>> command supports these aggregate functions:
 * experimental:[] <<esql-agg-weighted-avg>>
 // end::agg_list[]
 
-include::avg.asciidoc[]
 include::count.asciidoc[]
 include::count-distinct.asciidoc[]
 include::max.asciidoc[]
@@ -33,6 +32,7 @@ include::min.asciidoc[]
 include::percentile.asciidoc[]
 include::st_centroid_agg.asciidoc[]
 include::sum.asciidoc[]
+include::layout/avg.asciidoc[]
 include::layout/top.asciidoc[]
 include::values.asciidoc[]
 include::weighted-avg.asciidoc[]

+ 0 - 47
docs/reference/esql/functions/avg.asciidoc

@@ -1,47 +0,0 @@
-[discrete]
-[[esql-agg-avg]]
-=== `AVG`
-
-*Syntax*
-
-[source,esql]
-----
-AVG(expression)
-----
-
-`expression`::
-Numeric expression.
-//If `null`, the function returns `null`.
-// TODO: Remove comment when https://github.com/elastic/elasticsearch/issues/104900 is fixed.
-
-*Description*
-
-The average of a numeric expression.
-
-*Supported types*
-
-The result is always a `double` no matter the input type.
-
-*Examples*
-
-[source.merge.styled,esql]
-----
-include::{esql-specs}/stats.csv-spec[tag=avg]
-----
-[%header.monospaced.styled,format=dsv,separator=|]
-|===
-include::{esql-specs}/stats.csv-spec[tag=avg-result]
-|===
-
-The expression can use inline functions. For example, to calculate the average
-over a multivalued column, first use `MV_AVG` to average the multiple values per
-row, and use the result with the `AVG` function:
-
-[source.merge.styled,esql]
-----
-include::{esql-specs}/stats.csv-spec[tag=docsStatsAvgNestedExpression]
-----
-[%header.monospaced.styled,format=dsv,separator=|]
-|===
-include::{esql-specs}/stats.csv-spec[tag=docsStatsAvgNestedExpression-result]
-|===

+ 5 - 0
docs/reference/esql/functions/description/avg.asciidoc

@@ -0,0 +1,5 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.
+
+*Description*
+
+The average of a numeric field.

+ 22 - 0
docs/reference/esql/functions/examples/avg.asciidoc

@@ -0,0 +1,22 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.
+
+*Examples*
+
+[source.merge.styled,esql]
+----
+include::{esql-specs}/stats.csv-spec[tag=avg]
+----
+[%header.monospaced.styled,format=dsv,separator=|]
+|===
+include::{esql-specs}/stats.csv-spec[tag=avg-result]
+|===
+The expression can use inline functions. For example, to calculate the average over a multivalued column, first use `MV_AVG` to average the multiple values per row, and use the result with the `AVG` function
+[source.merge.styled,esql]
+----
+include::{esql-specs}/stats.csv-spec[tag=docsStatsAvgNestedExpression]
+----
+[%header.monospaced.styled,format=dsv,separator=|]
+|===
+include::{esql-specs}/stats.csv-spec[tag=docsStatsAvgNestedExpression-result]
+|===
+

+ 48 - 0
docs/reference/esql/functions/kibana/definition/avg.json

@@ -0,0 +1,48 @@
+{
+  "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.",
+  "type" : "agg",
+  "name" : "avg",
+  "description" : "The average of a numeric field.",
+  "signatures" : [
+    {
+      "params" : [
+        {
+          "name" : "number",
+          "type" : "double",
+          "optional" : false,
+          "description" : ""
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "double"
+    },
+    {
+      "params" : [
+        {
+          "name" : "number",
+          "type" : "integer",
+          "optional" : false,
+          "description" : ""
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "double"
+    },
+    {
+      "params" : [
+        {
+          "name" : "number",
+          "type" : "long",
+          "optional" : false,
+          "description" : ""
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "double"
+    }
+  ],
+  "examples" : [
+    "FROM employees\n| STATS AVG(height)",
+    "FROM employees\n| STATS avg_salary_change = ROUND(AVG(MV_AVG(salary_change)), 10)"
+  ]
+}

+ 11 - 0
docs/reference/esql/functions/kibana/docs/avg.md

@@ -0,0 +1,11 @@
+<!--
+This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.
+-->
+
+### AVG
+The average of a numeric field.
+
+```
+FROM employees
+| STATS AVG(height)
+```

+ 15 - 0
docs/reference/esql/functions/layout/avg.asciidoc

@@ -0,0 +1,15 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.
+
+[discrete]
+[[esql-avg]]
+=== `AVG`
+
+*Syntax*
+
+[.text-center]
+image::esql/functions/signature/avg.svg[Embedded,opts=inline]
+
+include::../parameters/avg.asciidoc[]
+include::../description/avg.asciidoc[]
+include::../types/avg.asciidoc[]
+include::../examples/avg.asciidoc[]

+ 6 - 0
docs/reference/esql/functions/parameters/avg.asciidoc

@@ -0,0 +1,6 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.
+
+*Parameters*
+
+`number`::
+

+ 1 - 0
docs/reference/esql/functions/signature/avg.svg

@@ -0,0 +1 @@
+<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="252" height="46" viewbox="0 0 252 46"><defs><style type="text/css">#guide .c{fill:none;stroke:#222222;}#guide .k{fill:#000000;font-family:Roboto Mono,Sans-serif;font-size:20px;}#guide .s{fill:#e4f4ff;stroke:#222222;}#guide .syn{fill:#8D8D8D;font-family:Roboto Mono,Sans-serif;font-size:20px;}</style></defs><path class="c" d="M0 31h5m56 0h10m32 0h10m92 0h10m32 0h5"/><rect class="s" x="5" y="5" width="56" height="36"/><text class="k" x="15" y="31">AVG</text><rect class="s" x="71" y="5" width="32" height="36" rx="7"/><text class="syn" x="81" y="31">(</text><rect class="s" x="113" y="5" width="92" height="36" rx="7"/><text class="k" x="123" y="31">number</text><rect class="s" x="215" y="5" width="32" height="36" rx="7"/><text class="syn" x="225" y="31">)</text></svg>

+ 11 - 0
docs/reference/esql/functions/types/avg.asciidoc

@@ -0,0 +1,11 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.
+
+*Supported types*
+
+[%header.monospaced.styled,format=dsv,separator=|]
+|===
+number | result
+double | double
+integer | double
+long | double
+|===

+ 15 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java

@@ -14,6 +14,7 @@ import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.type.DataType;
 import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
+import org.elasticsearch.xpack.esql.expression.function.Example;
 import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
 import org.elasticsearch.xpack.esql.expression.function.Param;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg;
@@ -28,7 +29,20 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isTyp
 public class Avg extends AggregateFunction implements SurrogateExpression {
     public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Avg", Avg::new);
 
-    @FunctionInfo(returnType = "double", description = "The average of a numeric field.", isAggregation = true)
+    @FunctionInfo(
+        returnType = "double",
+        description = "The average of a numeric field.",
+        isAggregation = true,
+        examples = {
+            @Example(file = "stats", tag = "avg"),
+            @Example(
+                description = "The expression can use inline functions. For example, to calculate the average "
+                    + "over a multivalued column, first use `MV_AVG` to average the multiple values per row, "
+                    + "and use the result with the `AVG` function",
+                file = "stats",
+                tag = "docsStatsAvgNestedExpression"
+            ) }
+    )
     public Avg(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
         super(source, field);
     }

+ 9 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java

@@ -15,6 +15,8 @@ import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
 import org.elasticsearch.xpack.esql.core.type.DataType;
 import org.elasticsearch.xpack.esql.core.util.NumericUtils;
 import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
@@ -251,6 +253,13 @@ public abstract class AbstractAggregationTestCase extends AbstractFunctionTestCa
         expression = new FoldNull().rule(expression);
         assertThat(expression.dataType(), equalTo(testCase.expectedType()));
 
+        assumeTrue(
+            "Surrogate expression with non-trivial children cannot be evaluated",
+            expression.children()
+                .stream()
+                .allMatch(child -> child instanceof FieldAttribute || child instanceof DeepCopy || child instanceof Literal)
+        );
+
         if (expression instanceof AggregateFunction == false) {
             onEvaluableExpression.accept(expression);
             return;

+ 95 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java

@@ -0,0 +1,95 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.aggregate;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase;
+import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class AvgTests extends AbstractAggregationTestCase {
+    public AvgTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
+        this.testCase = testCaseSupplier.get();
+    }
+
+    @ParametersFactory
+    public static Iterable<Object[]> parameters() {
+        var suppliers = new ArrayList<TestCaseSupplier>();
+
+        Stream.of(
+            MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true),
+            MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true),
+            MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true)
+        ).flatMap(List::stream).map(AvgTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers));
+
+        suppliers.add(
+            // Folding
+            new TestCaseSupplier(
+                List.of(DataType.INTEGER),
+                () -> new TestCaseSupplier.TestCase(
+                    List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field")),
+                    "Avg[field=Attribute[channel=0]]",
+                    DataType.DOUBLE,
+                    equalTo(200.)
+                )
+            )
+        );
+
+        return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers);
+    }
+
+    @Override
+    protected Expression build(Source source, List<Expression> args) {
+        return new Avg(source, args.get(0));
+    }
+
+    private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) {
+        return new TestCaseSupplier(List.of(fieldSupplier.type()), () -> {
+            var fieldTypedData = fieldSupplier.get();
+
+            Object expected = switch (fieldTypedData.type().widenSmallNumeric()) {
+                case INTEGER -> fieldTypedData.multiRowData()
+                    .stream()
+                    .map(v -> (Integer) v)
+                    .collect(Collectors.summarizingInt(Integer::intValue))
+                    .getAverage();
+                case LONG -> fieldTypedData.multiRowData()
+                    .stream()
+                    .map(v -> (Long) v)
+                    .collect(Collectors.summarizingLong(Long::longValue))
+                    .getAverage();
+                case DOUBLE -> fieldTypedData.multiRowData()
+                    .stream()
+                    .map(v -> (Double) v)
+                    .collect(Collectors.summarizingDouble(Double::doubleValue))
+                    .getAverage();
+                default -> throw new IllegalStateException("Unexpected value: " + fieldTypedData.type());
+            };
+
+            return new TestCaseSupplier.TestCase(
+                List.of(fieldTypedData),
+                "Avg[field=Attribute[channel=0]]",
+                DataType.DOUBLE,
+                equalTo(expected)
+            );
+        });
+    }
+}

+ 6 - 4
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java

@@ -22,6 +22,7 @@ import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.List;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.equalTo;
@@ -37,14 +38,15 @@ public class TopTests extends AbstractAggregationTestCase {
 
         for (var limitCaseSupplier : TestCaseSupplier.intCases(1, 1000, false)) {
             for (String order : List.of("asc", "desc")) {
-                for (var fieldCaseSupplier : Stream.of(
+                Stream.of(
                     MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true),
                     MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true),
                     MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true),
                     MultiRowTestCaseSupplier.dateCases(1, 1000)
-                ).flatMap(List::stream).toList()) {
-                    suppliers.add(TopTests.makeSupplier(fieldCaseSupplier, limitCaseSupplier, order));
-                }
+                )
+                    .flatMap(List::stream)
+                    .map(fieldCaseSupplier -> TopTests.makeSupplier(fieldCaseSupplier, limitCaseSupplier, order))
+                    .collect(Collectors.toCollection(() -> suppliers));
             }
         }