Browse Source

Add `mv_join` function (ESQL-1166)

Adds an `mv_join` function that joins together multivalue string fields.
You can combine this with out fancy new `to_string` to join together any
multivalued fields into a string.
Nik Everett 2 years ago
parent
commit
64e41ef100

+ 4 - 0
docs/reference/esql/esql-functions.asciidoc

@@ -23,6 +23,7 @@ these functions:
 * <<esql-length>>
 * <<esql-mv_avg>>
 * <<esql-mv_count>>
+* <<esql-mv_join>>
 * <<esql-mv_max>>
 * <<esql-mv_median>>
 * <<esql-mv_min>>
@@ -32,6 +33,7 @@ these functions:
 * <<esql-split>>
 * <<esql-starts_with>>
 * <<esql-substring>>
+* <<esql-to_string>>
 
 include::functions/abs.asciidoc[]
 include::functions/case.asciidoc[]
@@ -46,6 +48,7 @@ include::functions/is_null.asciidoc[]
 include::functions/length.asciidoc[]
 include::functions/mv_avg.asciidoc[]
 include::functions/mv_count.asciidoc[]
+include::functions/mv_join.asciidoc[]
 include::functions/mv_max.asciidoc[]
 include::functions/mv_median.asciidoc[]
 include::functions/mv_min.asciidoc[]
@@ -55,3 +58,4 @@ include::functions/round.asciidoc[]
 include::functions/split.asciidoc[]
 include::functions/starts_with.asciidoc[]
 include::functions/substring.asciidoc[]
+include::functions/to_string.asciidoc[]

+ 30 - 0
docs/reference/esql/functions/mv_join.asciidoc

@@ -0,0 +1,30 @@
+[[esql-mv_join]]
+=== `MV_JOIN`
+Converts a multivalued string field into a single valued field containing the
+concatenation of all values separated by a delimiter:
+
+[source,esql]
+----
+include::{esql-specs}/string.csv-spec[tag=mv_join]
+----
+
+Returns:
+
+[%header,format=dsv,separator=|]
+|===
+include::{esql-specs}/string.csv-spec[tag=mv_join-result]
+|===
+
+If you want to join non-string fields call <<esql-to_string>> on them first:
+[source,esql]
+----
+include::{esql-specs}/ints.csv-spec[tag=mv_join]
+----
+
+Returns:
+
+[%header,format=dsv,separator=|]
+|===
+include::{esql-specs}/ints.csv-spec[tag=mv_join-result]
+|===
+

+ 29 - 0
docs/reference/esql/functions/to_string.asciidoc

@@ -0,0 +1,29 @@
+[[esql-to_string]]
+=== `TO_STRING`
+Converts a field into a string. For example:
+
+[source,esql]
+----
+include::{esql-specs}/ints.csv-spec[tag=to_string]
+----
+
+which returns:
+
+[%header,format=dsv,separator=|]
+|===
+include::{esql-specs}/ints.csv-spec[tag=to_string-result]
+|===
+
+It also works fine on multivalued fields:
+
+[source,esql]
+----
+include::{esql-specs}/ints.csv-spec[tag=to_string_multivalue]
+----
+
+which returns:
+
+[%header,format=dsv,separator=|]
+|===
+include::{esql-specs}/ints.csv-spec[tag=to_string_multivalue-result]
+|===

+ 1 - 1
x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/MvEvaluator.java

@@ -42,7 +42,7 @@ public @interface MvEvaluator {
     String extraName() default "";
 
     /**
-     * Method name called to convert state into
+     * Method called to convert state into result.
      */
     String finish() default "";
 }

+ 39 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/ints.csv-spec

@@ -7,3 +7,42 @@ emp_no:integer |byte:keyword   |short:keyword  |long:keyword |int:keyword |langu
 10001          |2              |2              |2            |2           |2
 10002          |5              |5              |5            |5           |5
 ;
+
+convertToStringSimple
+// tag::to_string[]
+ROW a=10
+| EVAL j = TO_STRING(a)
+// end::to_string[]
+;
+
+// tag::to_string-result[]
+a:integer | j:keyword
+       10 | "10"
+// end::to_string-result[]
+;
+
+convertToStringMultivalue
+// tag::to_string_multivalue[]
+ROW a=[10, 9, 8]
+| EVAL j = TO_STRING(a)
+// end::to_string_multivalue[]
+;
+
+// tag::to_string_multivalue-result[]
+ a:integer | j:keyword
+[10, 9, 8] | ["10", "9", "8"]
+// end::to_string_multivalue-result[]
+;
+
+mvJoin
+// tag::mv_join[]
+ROW a=[10, 9, 8]
+| EVAL j = MV_JOIN(TO_STRING(a), ", ")
+// end::mv_join[]
+;
+
+// tag::mv_join-result[]
+ a:integer | j:keyword
+[10, 9, 8] | "10, 9, 8"
+// end::mv_join-result[]
+;

+ 1 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/show.csv-spec

@@ -30,6 +30,7 @@ median_absolute_deviation|median_absolute_deviation(arg1)
 min                      |min(arg1)
 mv_avg                   |mv_avg(arg1)
 mv_count                 |mv_count(arg1)
+mv_join                  |mv_join(arg1, arg2)
 mv_max                   |mv_max(arg1)
 mv_median                |mv_median(arg1)
 mv_min                   |mv_min(arg1)

+ 13 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec

@@ -232,6 +232,19 @@ ROW a=["foo", "zoo", "bar"]
 // end::mv_count-result[]
 ;
 
+mvJoin
+// tag::mv_join[]
+ROW a=["foo", "zoo", "bar"]
+| EVAL j = MV_JOIN(a, ", ")
+// end::mv_join[]
+;
+
+// tag::mv_join-result[]
+            a:keyword | j:keyword
+["foo", "zoo", "bar"] | "foo, zoo, bar"
+// end::mv_join-result[]
+;
+
 mvMax
 // tag::mv_max[]
 ROW a=["foo", "zoo", "bar"]

+ 2 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

@@ -30,6 +30,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow;
 import org.elasticsearch.xpack.esql.expression.function.scalar.math.Round;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount;
+import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvJoin;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedian;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin;
@@ -96,6 +97,7 @@ public class EsqlFunctionRegistry extends FunctionRegistry {
             new FunctionDefinition[] {
                 def(MvAvg.class, MvAvg::new, "mv_avg"),
                 def(MvCount.class, MvCount::new, "mv_count"),
+                def(MvJoin.class, MvJoin::new, "mv_join"),
                 def(MvMax.class, MvMax::new, "mv_max"),
                 def(MvMedian.class, MvMedian::new, "mv_median"),
                 def(MvMin.class, MvMin::new, "mv_min"),

+ 0 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvAvg.java

@@ -20,7 +20,6 @@ import org.elasticsearch.xpack.ql.type.DataTypes;
 import java.util.List;
 import java.util.function.Supplier;
 
-import static org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum.sum;
 import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.isRepresentable;
 import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isType;
 

+ 148 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvJoin.java

@@ -0,0 +1,148 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue;
+
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.BytesRefBuilder;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.EvalOperator;
+import org.elasticsearch.xpack.esql.planner.Mappable;
+import org.elasticsearch.xpack.ql.expression.Expression;
+import org.elasticsearch.xpack.ql.expression.TypeResolutions;
+import org.elasticsearch.xpack.ql.expression.function.scalar.BinaryScalarFunction;
+import org.elasticsearch.xpack.ql.tree.NodeInfo;
+import org.elasticsearch.xpack.ql.tree.Source;
+import org.elasticsearch.xpack.ql.type.DataType;
+import org.elasticsearch.xpack.ql.type.DataTypes;
+
+import java.util.function.Function;
+import java.util.function.Supplier;
+
+import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isString;
+
+/**
+ * Reduce a multivalued string field to a single valued field by concatenating all values.
+ */
+public class MvJoin extends BinaryScalarFunction implements Mappable {
+    public MvJoin(Source source, Expression field, Expression delim) {
+        super(source, field, delim);
+    }
+
+    @Override
+    protected TypeResolution resolveType() {
+        if (childrenResolved() == false) {
+            return new TypeResolution("Unresolved children");
+        }
+
+        TypeResolution resolution = isString(left(), sourceText(), TypeResolutions.ParamOrdinal.FIRST);
+        if (resolution.unresolved()) {
+            return resolution;
+        }
+
+        return isString(right(), sourceText(), TypeResolutions.ParamOrdinal.SECOND);
+    }
+
+    @Override
+    public DataType dataType() {
+        return DataTypes.KEYWORD;
+    }
+
+    @Override
+    public Supplier<EvalOperator.ExpressionEvaluator> toEvaluator(
+        Function<Expression, Supplier<EvalOperator.ExpressionEvaluator>> toEvaluator
+    ) {
+        Supplier<EvalOperator.ExpressionEvaluator> fieldEval = toEvaluator.apply(left());
+        Supplier<EvalOperator.ExpressionEvaluator> delimEval = toEvaluator.apply(right());
+        return () -> new MvJoinEvaluator(fieldEval.get(), delimEval.get());
+    }
+
+    @Override
+    public Object fold() {
+        return Mappable.super.fold();
+    }
+
+    @Override
+    protected BinaryScalarFunction replaceChildren(Expression newLeft, Expression newRight) {
+        return new MvJoin(source(), newLeft, newRight);
+    }
+
+    @Override
+    protected NodeInfo<? extends Expression> info() {
+        return NodeInfo.create(this, MvJoin::new, left(), right());
+    }
+
+    /**
+     * Evaluator for {@link MvJoin}. Not generated and doesn't extend from
+     * {@link AbstractMultivalueFunction.AbstractEvaluator} because it's just
+     * too different from all the other mv operators:
+     * <ul>
+     *     <li>It takes an extra parameter - the delimiter</li>
+     *     <li>That extra parameter makes it much more likely to be {@code null}</li>
+     *     <li>The actual joining process needs init step per row - {@link BytesRefBuilder#clear()}</li>
+     * </ul>
+     */
+    private class MvJoinEvaluator implements EvalOperator.ExpressionEvaluator {
+        private final EvalOperator.ExpressionEvaluator field;
+        private final EvalOperator.ExpressionEvaluator delim;
+
+        MvJoinEvaluator(EvalOperator.ExpressionEvaluator field, EvalOperator.ExpressionEvaluator delim) {
+            this.field = field;
+            this.delim = delim;
+        }
+
+        @Override
+        public final Block eval(Page page) {
+            Block fieldUncast = field.eval(page);
+            Block delimUncast = delim.eval(page);
+            if (fieldUncast.areAllValuesNull() || delimUncast.areAllValuesNull()) {
+                return Block.constantNullBlock(page.getPositionCount());
+            }
+            BytesRefBlock fieldVal = (BytesRefBlock) fieldUncast;
+            BytesRefBlock delimVal = (BytesRefBlock) delimUncast;
+
+            int positionCount = page.getPositionCount();
+            BytesRefBlock.Builder builder = BytesRefBlock.newBlockBuilder(positionCount);
+            BytesRefBuilder work = new BytesRefBuilder();
+            BytesRef fieldScratch = new BytesRef();
+            BytesRef delimScratch = new BytesRef();
+            for (int p = 0; p < positionCount; p++) {
+                int fieldValueCount = fieldVal.getValueCount(p);
+                if (fieldValueCount == 0) {
+                    builder.appendNull();
+                    continue;
+                }
+                if (delimVal.getValueCount(p) != 1) {
+                    builder.appendNull();
+                    continue;
+                }
+                int first = fieldVal.getFirstValueIndex(p);
+                if (fieldValueCount == 1) {
+                    builder.appendBytesRef(fieldVal.getBytesRef(first, fieldScratch));
+                    continue;
+                }
+                int end = first + fieldValueCount;
+                BytesRef delim = delimVal.getBytesRef(delimVal.getFirstValueIndex(p), delimScratch);
+                work.clear();
+                work.append(fieldVal.getBytesRef(first, fieldScratch));
+                for (int i = first + 1; i < end; i++) {
+                    work.append(delim);
+                    work.append(fieldVal.getBytesRef(i, fieldScratch));
+                }
+                builder.appendBytesRef(work.get());
+            }
+            return builder.build();
+        }
+
+        @Override
+        public final String toString() {
+            return "MvJoin[field=" + field + ", delim=" + delim + "]";
+        }
+    }
+}

+ 0 - 7
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSum.java

@@ -17,7 +17,6 @@ import org.elasticsearch.xpack.ql.tree.Source;
 
 import java.util.List;
 import java.util.function.Supplier;
-import java.util.stream.DoubleStream;
 
 import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.isRepresentable;
 import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isType;
@@ -35,12 +34,6 @@ public class MvSum extends AbstractMultivalueFunction {
         return isType(field(), t -> t.isNumeric() && isRepresentable(t), sourceText(), null, "numeric");
     }
 
-    static double sum(DoubleStream stream) {
-        CompensatedSum sum = new CompensatedSum();
-        stream.forEach(sum::add);
-        return sum.value();
-    }
-
     @Override
     protected Supplier<EvalOperator.ExpressionEvaluator> evaluator(Supplier<EvalOperator.ExpressionEvaluator> fieldEval) {
         return switch (LocalExecutionPlanner.toElementType(field().dataType())) {

+ 17 - 6
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java

@@ -39,6 +39,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.math.Round;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.AbstractMultivalueFunction;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount;
+import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvJoin;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedian;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin;
@@ -240,12 +241,13 @@ public final class PlanNamedTypes {
             of(AggregateFunction.class, MedianAbsoluteDeviation.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction),
             of(AggregateFunction.class, Sum.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction),
             // Multivalue functions
-            of(AbstractMultivalueFunction.class, MvAvg.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
-            of(AbstractMultivalueFunction.class, MvCount.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
-            of(AbstractMultivalueFunction.class, MvMax.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
-            of(AbstractMultivalueFunction.class, MvMedian.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
-            of(AbstractMultivalueFunction.class, MvMin.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
-            of(AbstractMultivalueFunction.class, MvSum.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
+            of(ScalarFunction.class, MvAvg.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
+            of(ScalarFunction.class, MvCount.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
+            of(ScalarFunction.class, MvJoin.class, PlanNamedTypes::writeMvJoin, PlanNamedTypes::readMvJoin),
+            of(ScalarFunction.class, MvMax.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
+            of(ScalarFunction.class, MvMedian.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
+            of(ScalarFunction.class, MvMin.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
+            of(ScalarFunction.class, MvSum.class, PlanNamedTypes::writeMvFunction, PlanNamedTypes::readMvFunction),
             // Expressions (other)
             of(Expression.class, Literal.class, PlanNamedTypes::writeLiteral, PlanNamedTypes::readLiteral),
             of(Expression.class, Order.class, PlanNamedTypes::writeOrder, PlanNamedTypes::readOrder)
@@ -898,6 +900,15 @@ public final class PlanNamedTypes {
         out.writeExpression(fn.field());
     }
 
+    static MvJoin readMvJoin(PlanStreamInput in) throws IOException {
+        return new MvJoin(Source.EMPTY, in.readExpression(), in.readExpression());
+    }
+
+    static void writeMvJoin(PlanStreamOutput out, MvJoin fn) throws IOException {
+        out.writeExpression(fn.left());
+        out.writeExpression(fn.right());
+    }
+
     // -- NamedExpressions
 
     static Alias readAlias(PlanStreamInput in) throws IOException {

+ 0 - 2
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunctionTestCase.java

@@ -70,8 +70,6 @@ public abstract class AbstractMultivalueFunctionTestCase extends AbstractScalarF
         return build(source, args.get(0));
     }
 
-    // TODO once we have explicit array types we should assert that non-arrays are noops
-
     @Override
     protected final Expression constantFoldable(List<Object> data) {
         return build(Source.EMPTY, new Literal(Source.EMPTY, data.get(0), DataTypes.fromJava(((List<?>) data.get(0)).get(0))));

+ 98 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvJoinTests.java

@@ -0,0 +1,98 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractScalarFunctionTestCase;
+import org.elasticsearch.xpack.ql.expression.Expression;
+import org.elasticsearch.xpack.ql.expression.Literal;
+import org.elasticsearch.xpack.ql.tree.Source;
+import org.elasticsearch.xpack.ql.type.DataType;
+import org.elasticsearch.xpack.ql.type.DataTypes;
+import org.hamcrest.Matcher;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.elasticsearch.compute.data.BlockUtils.toJavaObject;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.nullValue;
+
+public class MvJoinTests extends AbstractScalarFunctionTestCase {
+    @Override
+    protected Expression build(Source source, List<Literal> args) {
+        return new MvJoin(source, args.get(0), args.get(1));
+    }
+
+    @Override
+    protected List<Object> simpleData() {
+        return List.of(List.of(new BytesRef("foo"), new BytesRef("bar"), new BytesRef("baz")), new BytesRef(", "));
+    }
+
+    @Override
+    protected Expression expressionForSimpleData() {
+        return new MvJoin(Source.EMPTY, field("field", DataTypes.KEYWORD), field("delim", DataTypes.KEYWORD));
+    }
+
+    @Override
+    protected Matcher<Object> resultMatcher(List<Object> data) {
+        List<?> field = (List<?>) data.get(0);
+        BytesRef delim = (BytesRef) data.get(1);
+        if (field == null || delim == null) {
+            return nullValue();
+        }
+        return equalTo(
+            new BytesRef(field.stream().map(v -> ((BytesRef) v).utf8ToString()).collect(Collectors.joining(delim.utf8ToString())))
+        );
+    }
+
+    @Override
+    protected String expectedEvaluatorSimpleToString() {
+        return "MvJoin[field=Attribute[channel=0], delim=Attribute[channel=1]]";
+    }
+
+    @Override
+    protected Expression constantFoldable(List<Object> data) {
+        return new MvJoin(
+            Source.EMPTY,
+            new Literal(Source.EMPTY, data.get(0), DataTypes.KEYWORD),
+            new Literal(Source.EMPTY, data.get(1), DataTypes.KEYWORD)
+        );
+    }
+
+    @Override
+    protected List<ArgumentSpec> argSpec() {
+        return List.of(required(DataTypes.KEYWORD), required(DataTypes.KEYWORD));
+    }
+
+    @Override
+    protected DataType expectedType(List<DataType> argTypes) {
+        return DataTypes.KEYWORD;
+    }
+
+    public void testNull() {
+        BytesRef foo = new BytesRef("foo");
+        BytesRef bar = new BytesRef("bar");
+        BytesRef delim = new BytesRef(";");
+        Expression expression = expressionForSimpleData();
+
+        assertThat(toJavaObject(evaluator(expression).get().eval(row(Arrays.asList(Arrays.asList(foo, bar), null))), 0), nullValue());
+        assertThat(toJavaObject(evaluator(expression).get().eval(row(Arrays.asList(foo, null))), 0), nullValue());
+        assertThat(toJavaObject(evaluator(expression).get().eval(row(Arrays.asList(null, null))), 0), nullValue());
+
+        assertThat(
+            toJavaObject(evaluator(expression).get().eval(row(Arrays.asList(Arrays.asList(foo, bar), Arrays.asList(delim, bar)))), 0),
+            nullValue()
+        );
+        assertThat(toJavaObject(evaluator(expression).get().eval(row(Arrays.asList(foo, Arrays.asList(delim, bar)))), 0), nullValue());
+        assertThat(toJavaObject(evaluator(expression).get().eval(row(Arrays.asList(null, Arrays.asList(delim, bar)))), 0), nullValue());
+
+        assertThat(toJavaObject(evaluator(expression).get().eval(row(Arrays.asList(null, delim))), 0), nullValue());
+    }
+}