Browse Source

Implement `mv_count` (ESQL-1126)

Implements the `mv_count` function which returns a count of the values
in a column.
Nik Everett 2 years ago
parent
commit
b7b9f71a49

+ 2 - 0
docs/reference/esql/esql-functions.asciidoc

@@ -22,6 +22,7 @@ these functions:
 * <<esql-is_null>>
 * <<esql-length>>
 * <<esql-mv_avg>>
+* <<esql-mv_count>>
 * <<esql-mv_max>>
 * <<esql-mv_min>>
 * <<esql-mv_sum>>
@@ -43,6 +44,7 @@ include::functions/is_nan.asciidoc[]
 include::functions/is_null.asciidoc[]
 include::functions/length.asciidoc[]
 include::functions/mv_avg.asciidoc[]
+include::functions/mv_count.asciidoc[]
 include::functions/mv_max.asciidoc[]
 include::functions/mv_min.asciidoc[]
 include::functions/mv_sum.asciidoc[]

+ 12 - 0
docs/reference/esql/functions/mv_count.asciidoc

@@ -0,0 +1,12 @@
+[[esql-mv_count]]
+=== `MV_COUNT`
+Converts a multivalued field into a single valued field containing a count of the number
+of values:
+
+[source,esql]
+----
+include::{esql-specs}/string.csv-spec[tag=mv_count]
+include::{esql-specs}/string.csv-spec[tag=mv_count-result]
+----
+
+NOTE: This function accepts all types and always returns an `integer`.

+ 2 - 0
x-pack/plugin/esql/qa/server/single-node/src/yamlRestTest/resources/rest-api-spec/test/10_basic.yml

@@ -298,6 +298,7 @@ setup:
         - median_absolute_deviation
         - min
         - mv_avg
+        - mv_count
         - mv_max
         - mv_min
         - mv_sum
@@ -328,6 +329,7 @@ setup:
         - median_absolute_deviation(arg1)
         - min(arg1)
         - mv_avg(arg1)
+        - mv_count(arg1)
         - mv_max(arg1)
         - mv_min(arg1)
         - mv_sum(arg1)

+ 9 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec

@@ -213,6 +213,15 @@ ROW a=[3, 5, 1, 6]
 // end::mv_avg-result[]
 ;
 
+mvCount
+ROW a=[3, 5, 1, 6]
+| EVAL count_a = MV_COUNT(a)
+;
+
+   a:integer | count_a:integer
+[3, 5, 1, 6] | 4
+;
+
 
 mvMax
 from employees | where emp_no > 10008 | eval salary_change = mv_max(salary_change.int) | sort emp_no | project emp_no, salary_change.int, salary_change | limit 7;

+ 1 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/show.csv-spec

@@ -28,6 +28,7 @@ median                   |median(arg1)
 median_absolute_deviation|median_absolute_deviation(arg1)
 min                      |min(arg1)
 mv_avg                   |mv_avg(arg1)
+mv_count                 |mv_count(arg1)
 mv_max                   |mv_max(arg1)
 mv_min                   |mv_min(arg1)
 mv_sum                   |mv_sum(arg1)

+ 13 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec

@@ -210,6 +210,19 @@ foo;bar;baz;qux;quux;corge | [foo,bar,baz,qux,quux,corge]
 // end::split-result[]
 ;
 
+mvCount
+// tag::mv_count[]
+ROW a=["foo", "zoo", "bar"]
+| EVAL count_a = MV_COUNT(a)
+// end::mv_count[]
+;
+
+// tag::mv_count-result[]
+            a:keyword | count_a:integer
+["foo", "zoo", "bar"] | 3
+// end::mv_count-result[]
+;
+
 mvMax
 // tag::mv_max[]
 ROW a=["foo", "zoo", "bar"]

+ 2 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

@@ -27,6 +27,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.math.IsNaN;
 import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow;
 import org.elasticsearch.xpack.esql.expression.function.scalar.math.Round;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg;
+import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum;
@@ -88,6 +89,7 @@ public class EsqlFunctionRegistry extends FunctionRegistry {
             // multivalue functions
             new FunctionDefinition[] {
                 def(MvAvg.class, MvAvg::new, "mv_avg"),
+                def(MvCount.class, MvCount::new, "mv_count"),
                 def(MvMax.class, MvMax::new, "mv_max"),
                 def(MvMin.class, MvMin::new, "mv_min"),
                 def(MvSum.class, MvSum::new, "mv_sum"),

+ 98 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvCount.java

@@ -0,0 +1,98 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue;
+
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.IntArrayVector;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.Vector;
+import org.elasticsearch.compute.operator.EvalOperator;
+import org.elasticsearch.xpack.esql.type.EsqlDataTypes;
+import org.elasticsearch.xpack.ql.expression.Expression;
+import org.elasticsearch.xpack.ql.tree.NodeInfo;
+import org.elasticsearch.xpack.ql.tree.Source;
+import org.elasticsearch.xpack.ql.type.DataType;
+import org.elasticsearch.xpack.ql.type.DataTypes;
+
+import java.util.List;
+import java.util.function.Supplier;
+
+import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isType;
+
+/**
+ * Reduce a multivalued field to a single valued field containing the minimum value.
+ */
+public class MvCount extends AbstractMultivalueFunction {
+    public MvCount(Source source, Expression field) {
+        super(source, field);
+    }
+
+    @Override
+    protected TypeResolution resolveFieldType() {
+        return isType(field(), EsqlDataTypes::isRepresentable, sourceText(), null, "representable");
+    }
+
+    @Override
+    public DataType dataType() {
+        return DataTypes.INTEGER;
+    }
+
+    @Override
+    protected Object foldMultivalued(List<?> l) {
+        return l.size();
+    }
+
+    @Override
+    protected Supplier<EvalOperator.ExpressionEvaluator> evaluator(Supplier<EvalOperator.ExpressionEvaluator> fieldEval) {
+        return () -> new Evaluator(fieldEval.get());
+    }
+
+    @Override
+    public Expression replaceChildren(List<Expression> newChildren) {
+        return new MvCount(source(), newChildren.get(0));
+    }
+
+    @Override
+    protected NodeInfo<? extends Expression> info() {
+        return NodeInfo.create(this, MvCount::new, field());
+    }
+
+    private static class Evaluator extends AbstractMultivalueFunction.AbstractEvaluator {
+        protected Evaluator(EvalOperator.ExpressionEvaluator field) {
+            super(field);
+        }
+
+        @Override
+        protected String name() {
+            return "MvCount";
+        }
+
+        @Override
+        protected Block evalNullable(Block fieldVal) {
+            IntBlock.Builder builder = IntBlock.newBlockBuilder(fieldVal.getPositionCount());
+            for (int p = 0; p < fieldVal.getPositionCount(); p++) {
+                int valueCount = fieldVal.getValueCount(p);
+                if (valueCount == 0) {
+                    builder.appendNull();
+                    continue;
+                }
+                builder.appendInt(valueCount);
+            }
+            return builder.build();
+        }
+
+        @Override
+        protected Vector evalNotNullable(Block fieldVal) {
+            int[] values = new int[fieldVal.getPositionCount()];
+            for (int p = 0; p < fieldVal.getPositionCount(); p++) {
+                values[p] = fieldVal.getValueCount(p);
+            }
+            return new IntArrayVector(values, values.length);
+        }
+    }
+}

+ 3 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java

@@ -36,6 +36,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow;
 import org.elasticsearch.xpack.esql.expression.function.scalar.math.Round;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.AbstractMultivalueFunction;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg;
+import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum;
@@ -230,6 +231,7 @@ public final class PlanNamedTypes {
             of(AggregateFunction.class, Sum.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction),
             // Multivalue functions
             of(AbstractMultivalueFunction.class, MvAvg.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
+            of(AbstractMultivalueFunction.class, MvCount.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
             of(AbstractMultivalueFunction.class, MvMax.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
             of(AbstractMultivalueFunction.class, MvMin.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
             of(AbstractMultivalueFunction.class, MvSum.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
@@ -828,6 +830,7 @@ public final class PlanNamedTypes {
     // -- Multivalue functions
     static final Map<String, BiFunction<Source, Expression, AbstractMultivalueFunction>> MV_CTRS = Map.ofEntries(
         entry(name(MvAvg.class), MvAvg::new),
+        entry(name(MvCount.class), MvCount::new),
         entry(name(MvMax.class), MvMax::new),
         entry(name(MvMin.class), MvMin::new),
         entry(name(MvSum.class), MvSum::new)

+ 45 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvCountTests.java

@@ -0,0 +1,45 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue;
+
+import org.elasticsearch.xpack.ql.expression.Expression;
+import org.elasticsearch.xpack.ql.tree.Source;
+import org.elasticsearch.xpack.ql.type.DataType;
+import org.elasticsearch.xpack.ql.type.DataTypes;
+import org.hamcrest.Matcher;
+
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class MvCountTests extends AbstractMultivalueFunctionTestCase {
+    @Override
+    protected Expression build(Source source, Expression field) {
+        return new MvCount(source, field);
+    }
+
+    @Override
+    protected DataType[] supportedTypes() {
+        return representable();
+    }
+
+    @Override
+    protected DataType expectedType(List<DataType> argTypes) {
+        return DataTypes.INTEGER;
+    }
+
+    @Override
+    protected Matcher<Object> resultMatcherForInput(List<?> input) {
+        return equalTo(input.size());
+    }
+
+    @Override
+    protected String expectedEvaluatorSimpleToString() {
+        return "MvCount[field=Attribute[channel=0]]";
+    }
+}