Browse Source

Clean `UnaryScalarFunction` slightly (ESQL-1359)

Our base class for `UnaryScalarFunction` only takes one argument because
it's, well, unary. But it was reporting type errors on that argument as
though it were the first of many. That's silly.

I also added some tests for the `Abs` function which extends our
`UnaryScalarFunction` that would have caught this error. While I was
there I ported `Length` from QL's `UnaryScalarFunction` to ours. Let's
use our stuff. Even if it's wrong we can change it without bothing QL.

Finally I added some javadocs and removed some unused code.
Nik Everett 2 years ago
parent
commit
cbd4992e85

+ 2 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/UnaryScalarFunction.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.esql.expression.function.scalar;
 
 import org.elasticsearch.xpack.ql.expression.Expression;
+import org.elasticsearch.xpack.ql.expression.TypeResolutions;
 import org.elasticsearch.xpack.ql.expression.function.scalar.ScalarFunction;
 import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
 import org.elasticsearch.xpack.ql.tree.Source;
@@ -16,7 +17,6 @@ import org.elasticsearch.xpack.ql.type.DataType;
 import java.util.Arrays;
 import java.util.Objects;
 
-import static org.elasticsearch.xpack.ql.expression.TypeResolutions.ParamOrdinal.FIRST;
 import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isNumeric;
 
 public abstract class UnaryScalarFunction extends ScalarFunction {
@@ -33,7 +33,7 @@ public abstract class UnaryScalarFunction extends ScalarFunction {
             return new Expression.TypeResolution("Unresolved children");
         }
 
-        return isNumeric(field, sourceText(), FIRST);
+        return isNumeric(field, sourceText(), TypeResolutions.ParamOrdinal.DEFAULT);
     }
 
     @Override

+ 4 - 9
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Length.java

@@ -11,15 +11,15 @@ import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.UnicodeUtil;
 import org.elasticsearch.compute.ann.Evaluator;
 import org.elasticsearch.compute.operator.EvalOperator;
+import org.elasticsearch.xpack.esql.expression.function.scalar.UnaryScalarFunction;
 import org.elasticsearch.xpack.esql.planner.Mappable;
 import org.elasticsearch.xpack.ql.expression.Expression;
-import org.elasticsearch.xpack.ql.expression.function.scalar.UnaryScalarFunction;
-import org.elasticsearch.xpack.ql.expression.gen.processor.Processor;
 import org.elasticsearch.xpack.ql.tree.NodeInfo;
 import org.elasticsearch.xpack.ql.tree.Source;
 import org.elasticsearch.xpack.ql.type.DataType;
 import org.elasticsearch.xpack.ql.type.DataTypes;
 
+import java.util.List;
 import java.util.function.Function;
 import java.util.function.Supplier;
 
@@ -62,13 +62,8 @@ public class Length extends UnaryScalarFunction implements Mappable {
     }
 
     @Override
-    protected UnaryScalarFunction replaceChild(Expression newChild) {
-        return new Length(source(), newChild);
-    }
-
-    @Override
-    protected Processor makeProcessor() {
-        throw new UnsupportedOperationException();
+    public Expression replaceChildren(List<Expression> newChildren) {
+        return new Length(source(), newChildren.get(0));
     }
 
     @Override

+ 2 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java

@@ -266,7 +266,7 @@ public final class PlanNamedTypes {
             // UnaryScalarFunction
             of(QL_UNARY_SCLR_CLS, IsNotNull.class, PlanNamedTypes::writeQLUnaryScalar, PlanNamedTypes::readQLUnaryScalar),
             of(QL_UNARY_SCLR_CLS, IsNull.class, PlanNamedTypes::writeQLUnaryScalar, PlanNamedTypes::readQLUnaryScalar),
-            of(QL_UNARY_SCLR_CLS, Length.class, PlanNamedTypes::writeQLUnaryScalar, PlanNamedTypes::readQLUnaryScalar),
+            of(ESQL_UNARY_SCLR_CLS, Length.class, PlanNamedTypes::writeESQLUnaryScalar, PlanNamedTypes::readESQLUnaryScalar),
             of(QL_UNARY_SCLR_CLS, Not.class, PlanNamedTypes::writeQLUnaryScalar, PlanNamedTypes::readQLUnaryScalar),
             of(ESQL_UNARY_SCLR_CLS, Abs.class, PlanNamedTypes::writeESQLUnaryScalar, PlanNamedTypes::readESQLUnaryScalar),
             of(ScalarFunction.class, E.class, PlanNamedTypes::writeNoArgScalar, PlanNamedTypes::readNoArgScalar),
@@ -943,6 +943,7 @@ public final class PlanNamedTypes {
         entry(name(IsFinite.class), IsFinite::new),
         entry(name(IsInfinite.class), IsInfinite::new),
         entry(name(IsNaN.class), IsNaN::new),
+        entry(name(Length.class), Length::new),
         entry(name(Metadata.class), Metadata::new),
         entry(name(ToBoolean.class), ToBoolean::new),
         entry(name(ToDatetime.class), ToDatetime::new),
@@ -989,7 +990,6 @@ public final class PlanNamedTypes {
             Map.ofEntries(
                 entry(name(IsNotNull.class), IsNotNull::new),
                 entry(name(IsNull.class), IsNull::new),
-                entry(name(Length.class), Length::new),
                 entry(name(Not.class), Not::new)
             );
 

+ 25 - 1
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractScalarFunctionTestCase.java

@@ -32,15 +32,27 @@ import static org.hamcrest.Matchers.equalTo;
  * Base class for function tests.
  */
 public abstract class AbstractScalarFunctionTestCase extends AbstractFunctionTestCase {
-
+    /**
+     * Describe supported arguments. Build each argument with
+     * {@link #required} or {@link #optional}.
+     */
     protected abstract List<ArgumentSpec> argSpec();
 
+    /**
+     * The data type that applying this function to arguments of this type should produce.
+     */
     protected abstract DataType expectedType(List<DataType> argTypes);
 
+    /**
+     * Define a required argument.
+     */
     protected final ArgumentSpec required(DataType... validTypes) {
         return new ArgumentSpec(false, withNullAndSorted(validTypes));
     }
 
+    /**
+     * Define an optional argument.
+     */
     protected final ArgumentSpec optional(DataType... validTypes) {
         return new ArgumentSpec(true, withNullAndSorted(validTypes));
     }
@@ -52,18 +64,30 @@ public abstract class AbstractScalarFunctionTestCase extends AbstractFunctionTes
         return realValidTypes;
     }
 
+    /**
+     * All string types (keyword, text, match_only_text, etc). For passing to {@link #required} or {@link #optional}.
+     */
     protected final DataType[] strings() {
         return EsqlDataTypes.types().stream().filter(DataTypes::isString).toArray(DataType[]::new);
     }
 
+    /**
+     * All integer types (long, int, short, byte). For passing to {@link #required} or {@link #optional}.
+     */
     protected final DataType[] integers() {
         return EsqlDataTypes.types().stream().filter(DataType::isInteger).toArray(DataType[]::new);
     }
 
+    /**
+     * All rational types (double, float, whatever). For passing to {@link #required} or {@link #optional}.
+     */
     protected final DataType[] rationals() {
         return EsqlDataTypes.types().stream().filter(DataType::isRational).toArray(DataType[]::new);
     }
 
+    /**
+     * All numeric types (integers and rationals.) For passing to {@link #required} or {@link #optional}.
+     */
     protected final DataType[] numerics() {
         return EsqlDataTypes.types().stream().filter(DataType::isNumeric).toArray(DataType[]::new);
     }

+ 104 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/AbsTests.java

@@ -0,0 +1,104 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.scalar.math;
+
+import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractScalarFunctionTestCase;
+import org.elasticsearch.xpack.ql.expression.Expression;
+import org.elasticsearch.xpack.ql.expression.Literal;
+import org.elasticsearch.xpack.ql.tree.Source;
+import org.elasticsearch.xpack.ql.type.DataType;
+import org.elasticsearch.xpack.ql.type.DataTypes;
+import org.hamcrest.Matcher;
+
+import java.util.List;
+
+import static org.elasticsearch.compute.data.BlockUtils.toJavaObject;
+import static org.hamcrest.Matchers.equalTo;
+
+public class AbsTests extends AbstractScalarFunctionTestCase {
+    @Override
+    protected List<Object> simpleData() {
+        return List.of(randomInt());
+    }
+
+    @Override
+    protected Expression expressionForSimpleData() {
+        return new Abs(Source.EMPTY, field("arg", DataTypes.INTEGER));
+    }
+
+    @Override
+    protected Matcher<Object> resultMatcher(List<Object> data, DataType dataType) {
+        Object in = data.get(0);
+        if (dataType == DataTypes.INTEGER) {
+            return equalTo(Math.abs(((Integer) in).intValue()));
+        }
+        if (dataType == DataTypes.LONG) {
+            return equalTo(Math.abs(((Long) in).longValue()));
+        }
+        if (dataType == DataTypes.UNSIGNED_LONG) {
+            return equalTo(in);
+        }
+        if (dataType == DataTypes.DOUBLE) {
+            return equalTo(Math.abs(((Double) in).doubleValue()));
+        }
+        throw new IllegalArgumentException("can't match " + in);
+    }
+
+    @Override
+    protected String expectedEvaluatorSimpleToString() {
+        return "AbsIntEvaluator[fieldVal=Attribute[channel=0]]";
+    }
+
+    @Override
+    protected Expression constantFoldable(List<Object> data) {
+        return new Abs(Source.EMPTY, new Literal(Source.EMPTY, data.get(0), DataTypes.INTEGER));
+    }
+
+    @Override
+    protected Expression build(Source source, List<Literal> args) {
+        return new Abs(source, args.get(0));
+    }
+
+    @Override
+    protected List<ArgumentSpec> argSpec() {
+        return List.of(required(numerics()));
+    }
+
+    @Override
+    protected DataType expectedType(List<DataType> argTypes) {
+        return argTypes.get(0);
+    }
+
+    public final void testLong() {
+        List<Object> data = List.of(randomLong());
+        Expression expression = new Abs(Source.EMPTY, field("arg", DataTypes.LONG));
+        Object result = toJavaObject(evaluator(expression).get().eval(row(data)), 0);
+        assertThat(result, resultMatcher(data, DataTypes.LONG));
+    }
+
+    public final void testUnsignedLong() {
+        List<Object> data = List.of(randomLong());
+        Expression expression = new Abs(Source.EMPTY, field("arg", DataTypes.UNSIGNED_LONG));
+        Object result = toJavaObject(evaluator(expression).get().eval(row(data)), 0);
+        assertThat(result, resultMatcher(data, DataTypes.UNSIGNED_LONG));
+    }
+
+    public final void testInt() {
+        List<Object> data = List.of(randomInt());
+        Expression expression = new Abs(Source.EMPTY, field("arg", DataTypes.INTEGER));
+        Object result = toJavaObject(evaluator(expression).get().eval(row(data)), 0);
+        assertThat(result, resultMatcher(data, DataTypes.INTEGER));
+    }
+
+    public final void testDouble() {
+        List<Object> data = List.of(randomDouble());
+        Expression expression = new Abs(Source.EMPTY, field("arg", DataTypes.DOUBLE));
+        Object result = toJavaObject(evaluator(expression).get().eval(row(data)), 0);
+        assertThat(result, resultMatcher(data, DataTypes.DOUBLE));
+    }
+}

+ 0 - 1
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java

@@ -123,7 +123,6 @@ public class PowTests extends AbstractScalarFunctionTestCase {
 
     @Override
     protected List<ArgumentSpec> argSpec() {
-        var validDataTypes = new DataType[] { DataTypes.DOUBLE, DataTypes.LONG, DataTypes.INTEGER };
         return List.of(required(numerics()), required(numerics()));
     }