Browse Source

[ES|QL] weighted_avg (#109993)

* weighted_avg
Fang Xing 1 year ago
parent
commit
8abc8857f2

+ 5 - 0
docs/changelog/109993.yaml

@@ -0,0 +1,5 @@
+pr: 109993
+summary: "[ES|QL] `weighted_avg`"
+area: ES|QL
+type: enhancement
+issues: []

+ 2 - 0
docs/reference/esql/functions/aggregation-functions.asciidoc

@@ -20,6 +20,7 @@ The <<esql-stats-by>> command supports these aggregate functions:
 * <<esql-agg-sum>>
 * <<esql-top>>
 * <<esql-agg-values>>
+* experimental:[] <<esql-agg-weighted-avg>>
 // end::agg_list[]
 
 include::avg.asciidoc[]
@@ -34,3 +35,4 @@ include::st_centroid_agg.asciidoc[]
 include::sum.asciidoc[]
 include::layout/top.asciidoc[]
 include::values.asciidoc[]
+include::weighted-avg.asciidoc[]

+ 35 - 0
docs/reference/esql/functions/weighted-avg.asciidoc

@@ -0,0 +1,35 @@
+[discrete]
+[[esql-agg-weighted-avg]]
+=== `WEIGHTED_AVG`
+
+*Syntax*
+
+[source,esql]
+----
+WEIGHTED_AVG(expression, weight)
+----
+
+`expression`::
+Numeric expression.
+
+`weight`::
+Numeric weight.
+
+*Description*
+
+The weighted average of a numeric expression, that is, the sum of `expression * weight` divided by the sum of `weight`.
+
+*Supported types*
+
+Both `expression` and `weight` accept `double`, `integer`, or `long` values. The result is always a `double` no matter the input type.
+
+*Examples*
+
+[source.merge.styled,esql]
+----
+include::{esql-specs}/stats.csv-spec[tag=weighted-avg]
+----
+[%header.monospaced.styled,format=dsv,separator=|]
+|===
+include::{esql-specs}/stats.csv-spec[tag=weighted-avg-result]
+|===

+ 5 - 1
x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec

@@ -113,6 +113,7 @@ double tau()
 "double|integer|long|date top(field:double|integer|long|date, limit:integer, order:keyword)"
 "keyword|text trim(string:keyword|text)"
 "boolean|date|double|integer|ip|keyword|long|text|version values(field:boolean|date|double|integer|ip|keyword|long|text|version)"
+"double weighted_avg(number:double|integer|long, weight:double|integer|long)"
 ;
 
 metaFunctionsArgs#[skip:-8.14.99]
@@ -232,6 +233,7 @@ to_version    |field                               |"keyword|text|version"
 top      |[field, limit, order]               |["double|integer|long|date", integer, keyword]                                                                                    |[The field to collect the top values for.,The maximum number of values to collect.,The order to calculate the top values. Either `asc` or `desc`.]
 trim          |string                              |"keyword|text"                                                                                                                    |String expression. If `null`, the function returns `null`.
 values        |field                               |"boolean|date|double|integer|ip|keyword|long|text|version"                                                                        |[""]
+weighted_avg  |[number, weight]                    |["double|integer|long", "double|integer|long"]                                                                                    |[A numeric value., A numeric weight.]
 ;
 
 metaFunctionsDescription#[skip:-8.14.99]
@@ -352,6 +354,7 @@ to_version    |Converts an input string to a version value.
 top      |Collects the top values for a field. Includes repeated values.
 trim          |Removes leading and trailing whitespaces from a string.
 values        |Collect values for a field.
+weighted_avg  |The weighted average of a numeric field.
 ;
 
 metaFunctionsRemaining#[skip:-8.14.99]
@@ -473,6 +476,7 @@ to_version    |version
 top      |"double|integer|long|date"                                                                                                  |[false, false, false]       |false           |true
 trim          |"keyword|text"                                                                                                              |false                       |false           |false
 values        |"boolean|date|double|integer|ip|keyword|long|text|version"                                                                  |false                       |false           |true
+weighted_avg  |"double"                                                                                                                    |[false, false]              |false           |true
 ;
 
 metaFunctionsFiltered#[skip:-8.14.99]
@@ -491,5 +495,5 @@ countFunctions#[skip:-8.14.99, reason:BIN added]
 meta functions |  stats  a = count(*), b = count(*), c = count(*) |  mv_expand c;
 
 a:long | b:long | c:long
-111    | 111    | 111
+112    | 112    | 112
 ;

+ 122 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

@@ -1697,3 +1697,125 @@ FROM employees | STATS min = min(salary) by languages | SORT min + CASE(language
 29175          |2              
 28336          |null           
 ;
+
+
+weightedAvg
+required_capability: agg_weighted_avg
+from employees
+| stats w_avg_1 = weighted_avg(salary, 1), avg = avg(salary), w_avg_2 = weighted_avg(salary, height)
+| EVAL w_avg_1 = ROUND(w_avg_1), avg = ROUND(avg), w_avg_2 = ROUND(w_avg_2)
+;
+
+w_avg_1:double | avg:double | w_avg_2:double
+48249.0        | 48249.0   | 48474.0
+;
+
+weightedAvgGrouping
+required_capability: agg_weighted_avg
+// tag::weighted-avg[]
+FROM employees
+| STATS w_avg = WEIGHTED_AVG(salary, height) by languages
+| EVAL w_avg = ROUND(w_avg)
+| KEEP w_avg, languages
+| SORT languages
+// end::weighted-avg[]
+;
+
+// tag::weighted-avg-result[]
+w_avg:double | languages:integer
+51464.0      | 1
+48477.0      | 2
+52379.0      | 3
+47990.0      | 4
+42119.0      | 5
+52142.0      | null
+// end::weighted-avg-result[]
+;
+
+weightedAvgConstant
+required_capability: agg_weighted_avg
+row v = [1, 2, 3]
+| stats w_avg_1 = weighted_avg(v, 1), w_avg_2 = weighted_avg([1, 2, 3], 1), avg = avg(v)
+| EVAL w_avg_1 = ROUND(w_avg_1), w_avg_2 = ROUND(w_avg_2), avg = ROUND(avg)
+;
+
+w_avg_1:double |w_avg_2:double |avg:double
+2.0            | 2.0           | 2.0
+;
+
+weightedAvgBothConstantsMvWarning
+required_capability: agg_weighted_avg
+row v = [1, 2, 3], w = [1, 2, 3]
+| stats w_avg = weighted_avg(v, w)
+;
+warning:Line 2:17: evaluation of [weighted_avg(v, w)] failed, treating result as null. Only first 20 failures recorded.
+warning:Line 2:17: java.lang.IllegalArgumentException: single-value function encountered multi-value
+
+w_avg:double
+null
+;
+
+weightedAvgWeightConstantMvWarning
+required_capability: agg_weighted_avg
+from employees
+| eval w = [1, 2, 3]
+| stats w_avg = weighted_avg(salary, w)
+;
+warning:Line 3:17: evaluation of [weighted_avg(salary, w)] failed, treating result as null. Only first 20 failures recorded.
+warning:Line 3:17: java.lang.IllegalArgumentException: single-value function encountered multi-value
+
+w_avg:double
+null
+;
+
+weightedAvgWeightMvWarning
+required_capability: agg_weighted_avg
+from employees
+| where emp_no == 10002 or emp_no == 10003
+| stats w_avg = weighted_avg(salary, salary_change.int)
+;
+warning:Line 3:17: evaluation of [weighted_avg(salary, salary_change.int)] failed, treating result as null. Only first 20 failures recorded.
+warning:Line 3:17: java.lang.IllegalArgumentException: single-value function encountered multi-value
+
+w_avg:double
+null
+;
+
+weightedAvgFieldMvWarning
+required_capability: agg_weighted_avg
+from employees
+| where emp_no == 10002 or emp_no == 10003
+| stats w_avg = weighted_avg(salary_change.int, height)
+;
+warning:Line 3:17: evaluation of [weighted_avg(salary_change.int, height)] failed, treating result as null. Only first 20 failures recorded.
+warning:Line 3:17: java.lang.IllegalArgumentException: single-value function encountered multi-value
+
+w_avg:double
+null
+;
+
+weightedAvgWeightZero
+required_capability: agg_weighted_avg
+from employees
+| eval w = 0
+| stats w_avg = weighted_avg(salary, w)
+;
+warning:Line 3:17: evaluation of [weighted_avg(salary, w)] failed, treating result as null. Only first 20 failures recorded.
+warning:Line 3:17: java.lang.ArithmeticException: / by zero
+
+w_avg:double
+null
+;
+
+weightedAvgWeightZeroExp
+required_capability: agg_weighted_avg
+from employees
+| eval w = 0 + 0
+| stats w_avg = weighted_avg(salary, w)
+;
+warning:Line 3:17: evaluation of [weighted_avg(salary, w)] failed, treating result as null. Only first 20 failures recorded.
+warning:Line 3:17: java.lang.ArithmeticException: / by zero
+
+w_avg:double
+null
+;

+ 6 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

@@ -101,7 +101,12 @@ public class EsqlCapabilities {
         /**
          * Support for quoting index sources in double quotes.
          */
-        DOUBLE_QUOTES_SOURCE_ENCLOSING;
+        DOUBLE_QUOTES_SOURCE_ENCLOSING,
+
+        /**
+         * Support for WEIGHTED_AVG function.
+         */
+        AGG_WEIGHTED_AVG;
 
         private final boolean snapshotOnly;
 

+ 3 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

@@ -26,6 +26,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroi
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Top;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.WeightedAvg;
 import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
 import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
 import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
@@ -199,7 +200,8 @@ public final class EsqlFunctionRegistry extends FunctionRegistry {
                 def(Percentile.class, Percentile::new, "percentile"),
                 def(Sum.class, Sum::new, "sum"),
                 def(Top.class, Top::new, "top"),
-                def(Values.class, Values::new, "values") },
+                def(Values.class, Values::new, "values"),
+                def(WeightedAvg.class, WeightedAvg::new, "weighted_avg") },
             // math
             new FunctionDefinition[] {
                 def(Abs.class, Abs::new, "abs"),

+ 2 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java

@@ -45,7 +45,8 @@ public abstract class AggregateFunction extends Function {
             Values.ENTRY,
             // internal functions
             ToPartial.ENTRY,
-            FromPartial.ENTRY
+            FromPartial.ENTRY,
+            WeightedAvg.ENTRY
         );
     }
 

+ 145 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java

@@ -0,0 +1,145 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.aggregate;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.esql.capabilities.Validatable;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
+import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
+import org.elasticsearch.xpack.esql.expression.function.Param;
+import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg;
+import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
+import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
+
+public class WeightedAvg extends AggregateFunction implements SurrogateExpression, Validatable {
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+        Expression.class,
+        "WeightedAvg",
+        WeightedAvg::new
+    );
+
+    private final Expression weight;
+
+    private static final String invalidWeightError = "{} argument of [{}] cannot be null or 0, received [{}]";
+
+    @FunctionInfo(returnType = "double", description = "The weighted average of a numeric field.", isAggregation = true)
+    public WeightedAvg(
+        Source source,
+        @Param(name = "number", type = { "double", "integer", "long" }, description = "A numeric value.") Expression field,
+        @Param(name = "weight", type = { "double", "integer", "long" }, description = "A numeric weight.") Expression weight
+    ) {
+        super(source, field, List.of(weight));
+        this.weight = weight;
+    }
+
+    private WeightedAvg(StreamInput in) throws IOException {
+        this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class));
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        source().writeTo(out);
+        List<Expression> fields = children();
+        assert fields.size() == 2;
+        out.writeNamedWriteable(fields.get(0));
+        out.writeNamedWriteable(fields.get(1));
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
+    @Override
+    protected Expression.TypeResolution resolveType() {
+        if (childrenResolved() == false) {
+            return new TypeResolution("Unresolved children");
+        }
+
+        TypeResolution resolution = isType(
+            field(),
+            dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG,
+            sourceText(),
+            FIRST,
+            "numeric except unsigned_long or counter types"
+        );
+
+        if (resolution.unresolved()) {
+            return resolution;
+        }
+
+        resolution = isType(
+            weight(),
+            dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG,
+            sourceText(),
+            SECOND,
+            "numeric except unsigned_long or counter types"
+        );
+
+        if (resolution.unresolved()) {
+            return resolution;
+        }
+
+        if (weight.dataType() == DataType.NULL
+            || (weight.foldable() && (weight.fold() == null || weight.fold().equals(0) || weight.fold().equals(0.0)))) {
+            return new TypeResolution(format(null, invalidWeightError, SECOND, sourceText(), weight.foldable() ? weight.fold() : null));
+        }
+
+        return TypeResolution.TYPE_RESOLVED;
+    }
+
+    @Override
+    public DataType dataType() {
+        return DataType.DOUBLE;
+    }
+
+    @Override
+    protected NodeInfo<WeightedAvg> info() {
+        return NodeInfo.create(this, WeightedAvg::new, field(), weight);
+    }
+
+    @Override
+    public WeightedAvg replaceChildren(List<Expression> newChildren) {
+        return new WeightedAvg(source(), newChildren.get(0), newChildren.get(1));
+    }
+
+    @Override
+    public Expression surrogate() {
+        var s = source();
+        var field = field();
+        var weight = weight();
+
+        if (field.foldable()) {
+            return new MvAvg(s, field);
+        }
+        if (weight.foldable()) {
+            return new Div(s, new Sum(s, field), new Count(s, field), dataType());
+        } else {
+            return new Div(s, new Sum(s, new Mul(s, field, weight)), new Sum(s, weight), dataType());
+        }
+    }
+
+    public Expression weight() {
+        return weight;
+    }
+}

+ 31 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

@@ -595,6 +595,37 @@ public class VerifierTests extends ESTestCase {
              and inside another aggregate"""));
     }
 
+    public void testWeightedAvg() {
+        assertEquals(
+            "1:35: SECOND argument of [weighted_avg(v, null)] cannot be null or 0, received [null]",
+            error("row v = [1, 2, 3] | stats w_avg = weighted_avg(v, null)")
+        );
+        assertEquals(
+            "1:27: SECOND argument of [weighted_avg(salary, null)] cannot be null or 0, received [null]",
+            error("from test | stats w_avg = weighted_avg(salary, null)")
+        );
+        assertEquals(
+            "1:45: SECOND argument of [weighted_avg(v, w)] cannot be null or 0, received [null]",
+            error("row v = [1, 2, 3], w = null | stats w_avg = weighted_avg(v, w)")
+        );
+        assertEquals(
+            "1:44: SECOND argument of [weighted_avg(salary, w)] cannot be null or 0, received [null]",
+            error("from test | eval w = null |  stats w_avg = weighted_avg(salary, w)")
+        );
+        assertEquals(
+            "1:51: SECOND argument of [weighted_avg(salary, w)] cannot be null or 0, received [null]",
+            error("from test | eval w = null + null |  stats w_avg = weighted_avg(salary, w)")
+        );
+        assertEquals(
+            "1:35: SECOND argument of [weighted_avg(v, 0)] cannot be null or 0, received [0]",
+            error("row v = [1, 2, 3] | stats w_avg = weighted_avg(v, 0)")
+        );
+        assertEquals(
+            "1:27: SECOND argument of [weighted_avg(salary, 0.0)] cannot be null or 0, received [0.0]",
+            error("from test | stats w_avg = weighted_avg(salary, 0.0)")
+        );
+    }
+
     private String error(String query) {
         return error(query, defaultAnalyzer);
     }