Browse Source

Move ESQL's LOCATE test cases to cases (#107271)

This moves the test cases declared in the tests for ESQL's LOCATE
function to test cases which will cause #106782 to properly generate all
of the available signatures. It also buys us all of testing for
incorrect parameter combinations.
Nik Everett 1 year ago
parent
commit
8852566489

+ 2 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Locate.java

@@ -28,8 +28,8 @@ import java.util.function.Function;
 import static org.elasticsearch.xpack.ql.expression.TypeResolutions.ParamOrdinal.FIRST;
 import static org.elasticsearch.xpack.ql.expression.TypeResolutions.ParamOrdinal.SECOND;
 import static org.elasticsearch.xpack.ql.expression.TypeResolutions.ParamOrdinal.THIRD;
-import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isInteger;
 import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isString;
+import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isType;
 
 /**
  * Locate function, given a string 'a' and a substring 'b', it returns the index of the first occurrence of the substring 'b' in 'a'.
@@ -80,7 +80,7 @@ public class Locate extends EsqlScalarFunction implements OptionalArgument {
             return resolution;
         }
 
-        return start == null ? TypeResolution.TYPE_RESOLVED : isInteger(start, sourceText(), THIRD);
+        return start == null ? TypeResolution.TYPE_RESOLVED : isType(start, dt -> dt == DataTypes.INTEGER, sourceText(), THIRD, "integer");
     }
 
     @Override

+ 1 - 1
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java

@@ -72,7 +72,7 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
         this(nameFromTypes(types), types, supplier);
     }
 
-    static String nameFromTypes(List<DataType> types) {
+    public static String nameFromTypes(List<DataType> types) {
         return types.stream().map(t -> "<" + t.typeName() + ">").collect(Collectors.joining(", "));
     }
 

+ 149 - 167
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/LocateTests.java

@@ -11,22 +11,21 @@ import com.carrotsearch.randomizedtesting.annotations.Name;
 import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
 
 import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.compute.data.Block;
-import org.elasticsearch.compute.operator.EvalOperator;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
 import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
 import org.elasticsearch.xpack.ql.expression.Expression;
-import org.elasticsearch.xpack.ql.expression.Literal;
 import org.elasticsearch.xpack.ql.tree.Source;
 import org.elasticsearch.xpack.ql.type.DataType;
 import org.elasticsearch.xpack.ql.type.DataTypes;
 
+import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Locale;
+import java.util.function.Function;
 import java.util.function.Supplier;
 
-import static java.nio.charset.StandardCharsets.UTF_8;
-import static org.elasticsearch.compute.data.BlockUtils.toJavaObject;
 import static org.hamcrest.Matchers.equalTo;
 
 /**
@@ -37,192 +36,175 @@ public class LocateTests extends AbstractFunctionTestCase {
         this.testCase = testCaseSupplier.get();
     }
 
+    private static final DataType[] STRING_TYPES = new DataType[] { DataTypes.KEYWORD, DataTypes.TEXT };
+
     @ParametersFactory
     public static Iterable<Object[]> parameters() {
         List<TestCaseSupplier> suppliers = new ArrayList<>();
-        suppliers.add(
-            supplier(
-                "keywords",
-                DataTypes.KEYWORD,
-                DataTypes.KEYWORD,
-                () -> randomRealisticUnicodeOfCodepointLength(10),
-                () -> randomRealisticUnicodeOfCodepointLength(2),
-                () -> 0
-            )
-        );
-        suppliers.add(
-            supplier(
-                "mixed keyword, text",
-                DataTypes.KEYWORD,
-                DataTypes.TEXT,
-                () -> randomRealisticUnicodeOfCodepointLength(10),
-                () -> randomRealisticUnicodeOfCodepointLength(2),
-                () -> 0
-            )
-        );
-        suppliers.add(
-            supplier(
-                "texts",
-                DataTypes.TEXT,
-                DataTypes.TEXT,
-                () -> randomRealisticUnicodeOfCodepointLength(10),
-                () -> randomRealisticUnicodeOfCodepointLength(2),
-                () -> 0
-            )
-        );
-        suppliers.add(
-            supplier(
-                "mixed text, keyword",
-                DataTypes.TEXT,
-                DataTypes.KEYWORD,
-                () -> randomRealisticUnicodeOfCodepointLength(10),
-                () -> randomRealisticUnicodeOfCodepointLength(2),
-                () -> 0
-            )
-        );
-        return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(anyNullIsNull(true, suppliers)));
-    }
-
-    public void testToString() {
-        assertThat(
-            evaluator(
-                new Locate(
-                    Source.EMPTY,
-                    field("str", DataTypes.KEYWORD),
-                    field("substr", DataTypes.KEYWORD),
-                    field("start", DataTypes.INTEGER)
-                )
-            ).get(driverContext()).toString(),
-            equalTo("LocateEvaluator[str=Attribute[channel=0], substr=Attribute[channel=1], start=Attribute[channel=2]]")
-        );
-    }
-
-    @Override
-    protected Expression build(Source source, List<Expression> args) {
-        return new Locate(source, args.get(0), args.get(1), args.size() < 3 ? null : args.get(2));
-    }
-
-    public void testPrefixString() {
-        assertThat(process("a tiger", "a t", 0), equalTo(1));
-        assertThat(process("a tiger", "a", 0), equalTo(1));
-        assertThat(process("界世", "界", 0), equalTo(1));
-    }
-
-    public void testSuffixString() {
-        assertThat(process("a tiger", "er", 0), equalTo(6));
-        assertThat(process("a tiger", "r", 0), equalTo(7));
-        assertThat(process("世界", "界", 0), equalTo(2));
-    }
-
-    public void testMidString() {
-        assertThat(process("a tiger", "ti", 0), equalTo(3));
-        assertThat(process("a tiger", "ige", 0), equalTo(4));
-        assertThat(process("世界世", "界", 0), equalTo(2));
-    }
-
-    public void testOutOfRange() {
-        assertThat(process("a tiger", "tigers", 0), equalTo(0));
-        assertThat(process("a tiger", "ipa", 0), equalTo(0));
-        assertThat(process("世界世", "\uD83C\uDF0D", 0), equalTo(0));
-    }
-
-    public void testExactString() {
-        assertThat(process("a tiger", "a tiger", 0), equalTo(1));
-        assertThat(process("tigers", "tigers", 0), equalTo(1));
-        assertThat(process("界世", "界世", 0), equalTo(1));
-    }
+        for (DataType strType : STRING_TYPES) {
+            for (DataType substrType : STRING_TYPES) {
+                suppliers.add(
+                    supplier(
+                        "",
+                        strType,
+                        substrType,
+                        () -> randomRealisticUnicodeOfCodepointLength(10),
+                        str -> randomRealisticUnicodeOfCodepointLength(2),
+                        null,
+                        (str, substr, start) -> 1 + str.indexOf(substr)
+                    )
+                );
+                suppliers.add(
+                    supplier(
+                        "exact match ",
+                        strType,
+                        substrType,
+                        () -> randomRealisticUnicodeOfCodepointLength(10),
+                        str -> str,
+                        null,
+                        (str, substr, start) -> 1
+                    )
+                );
+                suppliers.add(
+                    supplier(
+                        "",
+                        strType,
+                        substrType,
+                        () -> randomRealisticUnicodeOfCodepointLength(10),
+                        str -> randomRealisticUnicodeOfCodepointLength(2),
+                        () -> between(0, 3),
+                        (str, substr, start) -> 1 + str.indexOf(substr, start)
+                    )
+                );
+            }
+        }
 
-    public void testSupplementaryCharacter() {
+        suppliers = errorsForCasesWithoutExamples(anyNullIsNull(true, suppliers));
+
+        // Here follows some non-randomized examples that we want to cover on every run
+        suppliers.add(supplier("a tiger", "a t", null, 1));
+        suppliers.add(supplier("a tiger", "a", null, 1));
+        suppliers.add(supplier("界世", "界", null, 1));
+        suppliers.add(supplier("a tiger", "er", null, 6));
+        suppliers.add(supplier("a tiger", "r", null, 7));
+        suppliers.add(supplier("界世", "世", null, 2));
+        suppliers.add(supplier("a tiger", "ti", null, 3));
+        suppliers.add(supplier("a tiger", "ige", null, 4));
+        suppliers.add(supplier("世界世", "界", null, 2));
+        suppliers.add(supplier("a tiger", "tigers", null, 0));
+        suppliers.add(supplier("a tiger", "ipa", null, 0));
+        suppliers.add(supplier("世界世", "\uD83C\uDF0D", null, 0));
+
+        // Extra assertions about 4-byte characters
         // some assertions about the supplementary (4-byte) character we'll use for testing
         assert "𠜎".length() == 2;
         assert "𠜎".codePointCount(0, 2) == 1;
-        assert "𠜎".getBytes(UTF_8).length == 4;
-
-        assertThat(process("a ti𠜎er", "𠜎er", 0), equalTo(5));
-        assertThat(process("a ti𠜎er", "i𠜎e", 0), equalTo(4));
-        assertThat(process("a ti𠜎er", "ti𠜎", 0), equalTo(3));
-        assertThat(process("a ti𠜎er", "er", 0), equalTo(6));
-        assertThat(process("a ti𠜎er", "r", 0), equalTo(7));
-
-        assertThat(process("𠜎a ti𠜎er", "𠜎er", 0), equalTo(6));
-        assertThat(process("𠜎a ti𠜎er", "i𠜎e", 0), equalTo(5));
-        assertThat(process("𠜎a ti𠜎er", "ti𠜎", 0), equalTo(4));
-        assertThat(process("𠜎a ti𠜎er", "er", 0), equalTo(7));
-        assertThat(process("𠜎a ti𠜎er", "r", 0), equalTo(8));
-
-        // exact
-        assertThat(process("a ti𠜎er", "a ti𠜎er", 0), equalTo(1));
-        assertThat(process("𠜎𠜎𠜎abc", "𠜎𠜎𠜎abc", 0), equalTo(1));
-        assertThat(process(" 𠜎𠜎𠜎abc", " 𠜎𠜎𠜎abc", 0), equalTo(1));
-        assertThat(process("𠜎𠜎𠜎 abc ", "𠜎𠜎𠜎 abc ", 0), equalTo(1));
-
+        assert "𠜎".getBytes(StandardCharsets.UTF_8).length == 4;
+        suppliers.add(supplier("a ti𠜎er", "𠜎er", null, 5));
+        suppliers.add(supplier("a ti𠜎er", "i𠜎e", null, 4));
+        suppliers.add(supplier("a ti𠜎er", "ti𠜎", null, 3));
+        suppliers.add(supplier("a ti𠜎er", "er", null, 6));
+        suppliers.add(supplier("a ti𠜎er", "r", null, 7));
+        suppliers.add(supplier("a ti𠜎er", "a ti𠜎er", null, 1));
         // prefix
-        assertThat(process("𠜎abc", "𠜎", 0), equalTo(1));
-        assertThat(process("𠜎 abc", "𠜎 ", 0), equalTo(1));
-        assertThat(process("𠜎𠜎𠜎abc", "𠜎𠜎𠜎", 0), equalTo(1));
-        assertThat(process("𠜎𠜎𠜎 abc", "𠜎𠜎𠜎 ", 0), equalTo(1));
-        assertThat(process(" 𠜎𠜎𠜎 abc", " 𠜎𠜎𠜎 ", 0), equalTo(1));
-        assertThat(process("𠜎 𠜎 𠜎 abc", "𠜎 𠜎 𠜎 ", 0), equalTo(1));
-
+        suppliers.add(supplier("𠜎abc", "𠜎", null, 1));
+        suppliers.add(supplier("𠜎 abc", "𠜎 ", null, 1));
+        suppliers.add(supplier("𠜎𠜎𠜎abc", "𠜎𠜎𠜎", null, 1));
+        suppliers.add(supplier("𠜎𠜎𠜎 abc", "𠜎𠜎𠜎 ", null, 1));
+        suppliers.add(supplier(" 𠜎𠜎𠜎 abc", " 𠜎𠜎𠜎 ", null, 1));
+        suppliers.add(supplier("𠜎 𠜎 𠜎 abc", "𠜎 𠜎 𠜎 ", null, 1));
         // suffix
-        assertThat(process("abc𠜎", "𠜎", 0), equalTo(4));
-        assertThat(process("abc 𠜎", " 𠜎", 0), equalTo(4));
-        assertThat(process("abc𠜎𠜎𠜎", "𠜎𠜎𠜎", 0), equalTo(4));
-        assertThat(process("abc 𠜎𠜎𠜎", " 𠜎𠜎𠜎", 0), equalTo(4));
-        assertThat(process("abc𠜎𠜎𠜎 ", "𠜎𠜎𠜎 ", 0), equalTo(4));
-
+        suppliers.add(supplier("abc𠜎", "𠜎", null, 4));
+        suppliers.add(supplier("abc 𠜎", " 𠜎", null, 4));
+        suppliers.add(supplier("abc𠜎𠜎𠜎", "𠜎𠜎𠜎", null, 4));
+        suppliers.add(supplier("abc 𠜎𠜎𠜎", " 𠜎𠜎𠜎", null, 4));
+        suppliers.add(supplier("abc𠜎𠜎𠜎 ", "𠜎𠜎𠜎 ", null, 4));
         // out of range
-        assertThat(process("𠜎a ti𠜎er", "𠜎a ti𠜎ers", 0), equalTo(0));
-        assertThat(process("a ti𠜎er", "aa ti𠜎er", 0), equalTo(0));
-        assertThat(process("abc𠜎𠜎", "𠜎𠜎𠜎", 0), equalTo(0));
+        suppliers.add(supplier("𠜎a ti𠜎er", "𠜎a ti𠜎ers", null, 0));
+        suppliers.add(supplier("a ti𠜎er", "aa ti𠜎er", null, 0));
+        suppliers.add(supplier("abc𠜎𠜎", "𠜎𠜎𠜎", null, 0));
 
         assert "🐱".length() == 2 && "🐶".length() == 2;
         assert "🐱".codePointCount(0, 2) == 1 && "🐶".codePointCount(0, 2) == 1;
-        assert "🐱".getBytes(UTF_8).length == 4 && "🐶".getBytes(UTF_8).length == 4;
-        assertThat(process("🐱Meow!🐶Woof!", "🐱Meow!🐶Woof!", 0), equalTo(1));
-        assertThat(process("🐱Meow!🐶Woof!", "Meow!🐶Woof!", 0), equalTo(2));
-        assertThat(process("🐱Meow!🐶Woof!", "eow!🐶Woof!", 0), equalTo(3));
+        assert "🐱".getBytes(StandardCharsets.UTF_8).length == 4 && "🐶".getBytes(StandardCharsets.UTF_8).length == 4;
+        suppliers.add(supplier("🐱Meow!🐶Woof!", "🐱Meow!🐶Woof!", null, 1));
+        suppliers.add(supplier("🐱Meow!🐶Woof!", "Meow!🐶Woof!", 0, 2));
+        suppliers.add(supplier("🐱Meow!🐶Woof!", "eow!🐶Woof!", 0, 3));
+
+        return parameterSuppliersFromTypedData(suppliers);
+    }
+
+    @Override
+    protected Expression build(Source source, List<Expression> args) {
+        return new Locate(source, args.get(0), args.get(1), args.size() < 3 ? null : args.get(2));
     }
 
-    private Integer process(String str, String substr, Integer start) {
-        try (
-            EvalOperator.ExpressionEvaluator eval = evaluator(
-                new Locate(
-                    Source.EMPTY,
-                    field("str", DataTypes.KEYWORD),
-                    field("substr", DataTypes.KEYWORD),
-                    new Literal(Source.EMPTY, start, DataTypes.INTEGER)
-                )
-            ).get(driverContext());
-            Block block = eval.eval(row(List.of(new BytesRef(str), new BytesRef(substr))))
-        ) {
-            return block.isNull(0) ? Integer.valueOf(0) : ((Integer) toJavaObject(block, 0));
+    private static TestCaseSupplier supplier(String str, String substr, @Nullable Integer start, @Nullable Integer expectedValue) {
+        String name = String.format(Locale.ROOT, "\"%s\" in \"%s\"", substr, str);
+        if (start != null) {
+            name += " starting at " + start;
         }
+
+        return new TestCaseSupplier(
+            name,
+            types(DataTypes.KEYWORD, DataTypes.KEYWORD, start != null),
+            () -> testCase(DataTypes.KEYWORD, DataTypes.KEYWORD, str, substr, start, expectedValue)
+        );
+    }
+
+    interface ExpectedValue {
+        int expectedValue(String str, String substr, Integer start);
     }
 
     private static TestCaseSupplier supplier(
         String name,
-        DataType firstType,
-        DataType secondType,
+        DataType strType,
+        DataType substrType,
         Supplier<String> strValueSupplier,
-        Supplier<String> substrValueSupplier,
-        Supplier<Integer> startSupplier
+        Function<String, String> substrValueSupplier,
+        @Nullable Supplier<Integer> startSupplier,
+        ExpectedValue expectedValue
     ) {
-        return new TestCaseSupplier(name, List.of(firstType, secondType), () -> {
-            List<TestCaseSupplier.TypedData> values = new ArrayList<>();
-            String expectedToString = "LocateEvaluator[str=Attribute[channel=0], substr=Attribute[channel=1], start=Attribute[channel=2]]";
-
-            String value = strValueSupplier.get();
-            values.add(new TestCaseSupplier.TypedData(new BytesRef(value), firstType, "0"));
+        List<DataType> types = types(strType, substrType, startSupplier != null);
+        return new TestCaseSupplier(name + TestCaseSupplier.nameFromTypes(types), types, () -> {
+            String str = strValueSupplier.get();
+            String substr = substrValueSupplier.apply(str);
+            Integer start = startSupplier == null ? null : startSupplier.get();
+            return testCase(strType, substrType, str, substr, start, expectedValue.expectedValue(str, substr, start));
+        });
+    }
 
-            String substrValue = substrValueSupplier.get();
-            values.add(new TestCaseSupplier.TypedData(new BytesRef(substrValue), secondType, "1"));
+    private static String expectedToString(boolean hasStart) {
+        if (hasStart) {
+            return "LocateEvaluator[str=Attribute[channel=0], substr=Attribute[channel=1], start=Attribute[channel=2]]";
+        }
+        return "LocateNoStartEvaluator[str=Attribute[channel=0], substr=Attribute[channel=1]]";
+    }
 
-            Integer startValue = startSupplier.get();
-            values.add(new TestCaseSupplier.TypedData(startValue, DataTypes.INTEGER, "2"));
+    private static List<DataType> types(DataType firstType, DataType secondType, boolean hasStart) {
+        List<DataType> types = new ArrayList<>();
+        types.add(firstType);
+        types.add(secondType);
+        if (hasStart) {
+            types.add(DataTypes.INTEGER);
+        }
+        return types;
+    }
 
-            int expectedValue = 1 + value.indexOf(substrValue);
-            return new TestCaseSupplier.TestCase(values, expectedToString, DataTypes.INTEGER, equalTo(expectedValue));
-        });
+    private static TestCaseSupplier.TestCase testCase(
+        DataType strType,
+        DataType substrType,
+        String str,
+        String substr,
+        Integer start,
+        Integer expectedValue
+    ) {
+        List<TestCaseSupplier.TypedData> values = new ArrayList<>();
+        values.add(new TestCaseSupplier.TypedData(str == null ? null : new BytesRef(str), strType, "str"));
+        values.add(new TestCaseSupplier.TypedData(substr == null ? null : new BytesRef(substr), substrType, "substr"));
+        if (start != null) {
+            values.add(new TestCaseSupplier.TypedData(start, DataTypes.INTEGER, "start"));
+        }
+        return new TestCaseSupplier.TestCase(values, expectedToString(start != null), DataTypes.INTEGER, equalTo(expectedValue));
     }
 }