Browse Source

add tests for SimplifyComparisonArithmetics optimization rule (#108744)

This adds in the tests from OptimizerRunTests in SQL to apply to ESQL. I've opened issues and applied the AwaitsFix annotation for those of the tests that are currently failing.
Mark Tozzi 1 year ago
parent
commit
9f54d9a804

+ 4 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java

@@ -86,6 +86,10 @@ public abstract class EsqlBinaryComparison extends BinaryComparison implements E
             throw new IOException("No BinaryComparisonOperation found for id [" + id + "]");
         }
 
+        public String symbol() {
+            return symbol;
+        }
+
         public EsqlBinaryComparison buildNewInstance(Source source, Expression lhs, Expression rhs) {
             return constructor.apply(source, lhs, rhs);
         }

+ 220 - 1
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.esql.optimizer;
 import org.elasticsearch.common.logging.LoggerMessageFormat;
 import org.elasticsearch.common.lucene.BytesRefs;
 import org.elasticsearch.compute.aggregation.QuantileStates;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.esql.EsqlTestUtils;
 import org.elasticsearch.xpack.esql.TestBlockFactory;
@@ -65,6 +66,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul
 import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
+import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
@@ -95,19 +97,23 @@ import org.elasticsearch.xpack.ql.expression.Literal;
 import org.elasticsearch.xpack.ql.expression.NamedExpression;
 import org.elasticsearch.xpack.ql.expression.Nullability;
 import org.elasticsearch.xpack.ql.expression.ReferenceAttribute;
+import org.elasticsearch.xpack.ql.expression.UnresolvedAttribute;
 import org.elasticsearch.xpack.ql.expression.predicate.Predicates;
 import org.elasticsearch.xpack.ql.expression.predicate.logical.And;
 import org.elasticsearch.xpack.ql.expression.predicate.logical.Or;
 import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNotNull;
 import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNull;
+import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison;
 import org.elasticsearch.xpack.ql.expression.predicate.regex.RLikePattern;
 import org.elasticsearch.xpack.ql.expression.predicate.regex.WildcardPattern;
 import org.elasticsearch.xpack.ql.index.EsIndex;
 import org.elasticsearch.xpack.ql.index.IndexResolution;
+import org.elasticsearch.xpack.ql.optimizer.OptimizerRules;
 import org.elasticsearch.xpack.ql.plan.logical.Filter;
 import org.elasticsearch.xpack.ql.plan.logical.Limit;
 import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.ql.plan.logical.OrderBy;
+import org.elasticsearch.xpack.ql.plan.logical.UnaryPlan;
 import org.elasticsearch.xpack.ql.tree.Source;
 import org.elasticsearch.xpack.ql.type.DataType;
 import org.elasticsearch.xpack.ql.type.DataTypes;
@@ -117,6 +123,7 @@ import org.elasticsearch.xpack.ql.util.StringUtils;
 import org.junit.BeforeClass;
 
 import java.lang.reflect.Constructor;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
@@ -137,6 +144,11 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.localSource;
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
 import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS;
+import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.EQ;
+import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GT;
+import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GTE;
+import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LT;
+import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LTE;
 import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_POINT;
 import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_SHAPE;
 import static org.elasticsearch.xpack.ql.TestUtils.getFieldAttribute;
@@ -173,16 +185,18 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
     private static final Literal ONE = L(1);
     private static final Literal TWO = L(2);
     private static final Literal THREE = L(3);
-
     private static EsqlParser parser;
     private static Analyzer analyzer;
     private static LogicalPlanOptimizer logicalOptimizer;
     private static Map<String, EsField> mapping;
     private static Map<String, EsField> mappingAirports;
+    private static Map<String, EsField> mappingTypes;
     private static Analyzer analyzerAirports;
+    private static Analyzer analyzerTypes;
     private static Map<String, EsField> mappingExtra;
     private static Analyzer analyzerExtra;
     private static EnrichResolution enrichResolution;
+    private static final OptimizerRules.LiteralsOnTheRight LITERALS_ON_THE_RIGHT = new OptimizerRules.LiteralsOnTheRight();
 
     private static class SubstitutionOnlyOptimizer extends LogicalPlanOptimizer {
         static SubstitutionOnlyOptimizer INSTANCE = new SubstitutionOnlyOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG));
@@ -222,6 +236,15 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
             TEST_VERIFIER
         );
 
+        // Some tests need additional types, so we load that index here and use it in the plan_types() function.
+        mappingTypes = loadMapping("mapping-all-types.json");
+        EsIndex types = new EsIndex("types", mappingTypes, Set.of("types"));
+        IndexResolution getIndexResultTypes = IndexResolution.valid(types);
+        analyzerTypes = new Analyzer(
+            new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultTypes, enrichResolution),
+            TEST_VERIFIER
+        );
+
         // Some tests use mappings from mapping-extra.json to be able to test more types so we load it here
         mappingExtra = loadMapping("mapping-extra.json");
         EsIndex extra = new EsIndex("extra", mappingExtra, Set.of("extra"));
@@ -4438,11 +4461,207 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
         return optimized;
     }
 
+    private LogicalPlan planTypes(String query) {
+        return logicalOptimizer.optimize(analyzerTypes.analyze(parser.createStatement(query)));
+    }
+
+    private EsqlBinaryComparison extractPlannedBinaryComparison(String expression) {
+        LogicalPlan plan = planTypes("FROM types | WHERE " + expression);
+
+        return extractPlannedBinaryComparison(plan);
+    }
+
+    private static EsqlBinaryComparison extractPlannedBinaryComparison(LogicalPlan plan) {
+        assertTrue("Expected unary plan, found [" + plan + "]", plan instanceof UnaryPlan);
+        UnaryPlan unaryPlan = (UnaryPlan) plan;
+        assertTrue("Epxected top level Filter, foung [" + unaryPlan.child().toString() + "]", unaryPlan.child() instanceof Filter);
+        Filter filter = (Filter) unaryPlan.child();
+        assertTrue(
+            "Expected filter condition to be a binary comparison but found [" + filter.condition() + "]",
+            filter.condition() instanceof EsqlBinaryComparison
+        );
+        return (EsqlBinaryComparison) filter.condition();
+    }
+
+    private void doTestSimplifyComparisonArithmetics(
+        String expression,
+        String fieldName,
+        EsqlBinaryComparison.BinaryComparisonOperation opType,
+        Object bound
+    ) {
+        EsqlBinaryComparison bc = extractPlannedBinaryComparison(expression);
+        assertEquals(opType, bc.getFunctionType());
+
+        assertTrue(
+            "Expected left side of comparison to be a field attribute but found [" + bc.left() + "]",
+            bc.left() instanceof FieldAttribute
+        );
+        FieldAttribute attribute = (FieldAttribute) bc.left();
+        assertEquals(fieldName, attribute.name());
+
+        assertTrue("Expected right side of comparison to be a literal but found [" + bc.right() + "]", bc.right() instanceof Literal);
+        Literal literal = (Literal) bc.right();
+        assertEquals(bound, literal.value());
+    }
+
+    private void assertSemanticMatching(String expected, String provided) {
+        BinaryComparison bc = extractPlannedBinaryComparison(provided);
+        LogicalPlan exp = analyzerTypes.analyze(parser.createStatement("FROM types | WHERE " + expected));
+        assertSemanticMatching(bc, extractPlannedBinaryComparison(exp));
+    }
+
+    private static void assertSemanticMatching(Expression fieldAttributeExp, Expression unresolvedAttributeExp) {
+        Expression unresolvedUpdated = unresolvedAttributeExp.transformUp(
+            LITERALS_ON_THE_RIGHT.expressionToken(),
+            LITERALS_ON_THE_RIGHT::rule
+        ).transformUp(x -> x.foldable() ? new Literal(x.source(), x.fold(), x.dataType()) : x);
+
+        List<Expression> resolvedFields = fieldAttributeExp.collectFirstChildren(x -> x instanceof FieldAttribute);
+        for (Expression field : resolvedFields) {
+            FieldAttribute fa = (FieldAttribute) field;
+            unresolvedUpdated = unresolvedUpdated.transformDown(UnresolvedAttribute.class, x -> x.name().equals(fa.name()) ? fa : x);
+        }
+
+        assertTrue(unresolvedUpdated.semanticEquals(fieldAttributeExp));
+    }
+
+    private Expression getComparisonFromLogicalPlan(LogicalPlan plan) {
+        List<Expression> expressions = new ArrayList<>();
+        plan.forEachExpression(Expression.class, expressions::add);
+        return expressions.get(0);
+    }
+
+    private void assertNotSimplified(String comparison) {
+        String query = "FROM types | WHERE " + comparison;
+        Expression optimized = getComparisonFromLogicalPlan(planTypes(query));
+        Expression raw = getComparisonFromLogicalPlan(analyzerTypes.analyze(parser.createStatement(query)));
+
+        assertTrue(raw.semanticEquals(optimized));
+    }
+
+    private static String randomBinaryComparison() {
+        return randomFrom(EsqlBinaryComparison.BinaryComparisonOperation.values()).symbol();
+    }
+
+    public void testSimplifyComparisonArithmeticCommutativeVsNonCommutativeOps() {
+        doTestSimplifyComparisonArithmetics("integer + 2 > 3", "integer", GT, 1);
+        doTestSimplifyComparisonArithmetics("2 + integer > 3", "integer", GT, 1);
+        doTestSimplifyComparisonArithmetics("integer - 2 > 3", "integer", GT, 5);
+        doTestSimplifyComparisonArithmetics("2 - integer > 3", "integer", LT, -1);
+        doTestSimplifyComparisonArithmetics("integer * 2 > 4", "integer", GT, 2);
+        doTestSimplifyComparisonArithmetics("2 * integer > 4", "integer", GT, 2);
+
+    }
+
+    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108388")
+    public void testSimplifyComparisonArithmeticsWithFloatingPoints() {
+        doTestSimplifyComparisonArithmetics("float / 2 > 4", "float", GT, 8d);
+    }
+
+    public void testAssertSemanticMatching() {
+        // This test is just to verify that the complicated assert logic is working on a known-good case
+        assertSemanticMatching("integer > 1", "integer + 2 > 3");
+    }
+
+    public void testSimplyComparisonArithmeticWithUnfoldedProd() {
+        assertSemanticMatching("integer * integer >= 3", "((integer * integer + 1) * 2 - 4) * 4 >= 16");
+    }
+
+    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108524")
+    public void testSimplifyComparisionArithmetics_floatDivision() {
+        doTestSimplifyComparisonArithmetics("2 / float < 4", "float", GT, .5);
+    }
+
+    public void testSimplifyComparisonArithmeticWithMultipleOps() {
+        // i >= 3
+        doTestSimplifyComparisonArithmetics("((integer + 1) * 2 - 4) * 4 >= 16", "integer", GTE, 3);
+    }
+
+    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743")
+    public void testSimplifyComparisonArithmeticWithFieldNegation() {
+        doTestSimplifyComparisonArithmetics("12 * (-integer - 5) >= -120", "integer", LTE, 5);
+    }
+
+    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743")
+    public void testSimplifyComparisonArithmeticWithFieldDoubleNegation() {
+        doTestSimplifyComparisonArithmetics("12 * -(-integer - 5) <= 120", "integer", LTE, 5);
+    }
+
+    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743")
+    public void testSimplifyComparisonArithmeticWithConjunction() {
+        doTestSimplifyComparisonArithmetics("12 * (-integer - 5) == -120 AND integer < 6 ", "integer", EQ, 5);
+    }
+
+    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108525")
+    public void testSimplifyComparisonArithmeticWithDisjunction() {
+        doTestSimplifyComparisonArithmetics("12 * (-integer - 5) >= -120 OR integer < 5", "integer", LTE, 5);
+    }
+
+    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108388")
+    public void testSimplifyComparisonArithmeticWithFloatsAndDirectionChange() {
+        doTestSimplifyComparisonArithmetics("float / -2 < 4", "float", GT, -8d);
+        doTestSimplifyComparisonArithmetics("float * -2 < 4", "float", GT, -2d);
+    }
+
     private void assertNullLiteral(Expression expression) {
         assertEquals(Literal.class, expression.getClass());
         assertNull(expression.fold());
     }
 
+    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519")
+    public void testSimplifyComparisonArithmeticSkippedOnIntegerArithmeticalOverflow() {
+        assertNotSimplified("integer - 1 " + randomBinaryComparison() + " " + Long.MAX_VALUE);
+        assertNotSimplified("1 - integer " + randomBinaryComparison() + " " + Long.MIN_VALUE);
+        assertNotSimplified("integer - 1 " + randomBinaryComparison() + " " + Integer.MAX_VALUE);
+        assertNotSimplified("1 - integer " + randomBinaryComparison() + " " + Integer.MIN_VALUE);
+    }
+
+    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519")
+    public void testSimplifyComparisonArithmeticSkippedOnNegatingOverflow() {
+        assertNotSimplified("-integer " + randomBinaryComparison() + " " + Long.MIN_VALUE);
+        assertNotSimplified("-integer " + randomBinaryComparison() + " " + Integer.MIN_VALUE);
+    }
+
+    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519")
+    public void testSimplifyComparisonArithmeticSkippedOnDateOverflow() {
+        assertNotSimplified("date - 999999999 years > to_datetime(\"2010-01-01T01:01:01\")");
+        assertNotSimplified("date + -999999999 years > to_datetime(\"2010-01-01T01:01:01\")");
+    }
+
+    public void testSimplifyComparisonArithmeticSkippedOnMulDivByZero() {
+        assertNotSimplified("float / 0 " + randomBinaryComparison() + " 1");
+        assertNotSimplified("float * 0 " + randomBinaryComparison() + " 1");
+        assertNotSimplified("integer / 0 " + randomBinaryComparison() + " 1");
+        assertNotSimplified("integer * 0 " + randomBinaryComparison() + " 1");
+    }
+
+    public void testSimplifyComparisonArithmeticSkippedOnDiv() {
+        assertNotSimplified("integer / 4 " + randomBinaryComparison() + " 1");
+        assertNotSimplified("4 / integer " + randomBinaryComparison() + " 1");
+    }
+
+    public void testSimplifyComparisonArithmeticSkippedOnResultingFloatLiteral() {
+        assertNotSimplified("integer * 2 " + randomBinaryComparison() + " 3");
+    }
+
+    public void testSimplifyComparisonArithmeticSkippedOnFloatFieldWithPlusMinus() {
+        assertNotSimplified("float + 4 " + randomBinaryComparison() + " 1");
+        assertNotSimplified("4 + float " + randomBinaryComparison() + " 1");
+        assertNotSimplified("float - 4 " + randomBinaryComparison() + " 1");
+        assertNotSimplified("4 - float " + randomBinaryComparison() + " 1");
+    }
+
+    public void testSimplifyComparisonArithmeticSkippedOnFloats() {
+        for (String field : List.of("integer", "float")) {
+            for (Tuple<? extends Number, ? extends Number> nr : List.of(new Tuple<>(.4, 1), new Tuple<>(1, .4))) {
+                assertNotSimplified(field + " + " + nr.v1() + " " + randomBinaryComparison() + " " + nr.v2());
+                assertNotSimplified(field + " - " + nr.v1() + " " + randomBinaryComparison() + " " + nr.v2());
+                assertNotSimplified(nr.v1() + " + " + field + " " + randomBinaryComparison() + " " + nr.v2());
+                assertNotSimplified(nr.v1() + " - " + field + " " + randomBinaryComparison() + " " + nr.v2());
+            }
+        }
+    }
+
     public static WildcardLike wildcardLike(Expression left, String exp) {
         return new WildcardLike(EMPTY, left, new WildcardPattern(exp));
     }