Browse Source

ESQL: Simplify IS NULL/IS NOT NULL evaluation (#103099)

The IS NULL and IS NOT NULL predicates only care about the nullability,
 not the value, of an expression. Given the expression x + 1 / 2, the
 actual result does not matter, only whether it is null or not - that is,
 it only matters whether x is null or not.
 So x + 1 / 2 IS NULL becomes x IS NULL - which can be opportunistically
 pushed down or evaluated.
The original expression is preserved to cope with under/overflow or
 multi-value (mv) fields.

Fix #103097
Costin Leau 1 year ago
parent
commit
2c8f6eba8c

+ 6 - 0
docs/changelog/103099.yaml

@@ -0,0 +1,6 @@
+pr: 103099
+summary: "ESQL: Simplify IS NULL/IS NOT NULL evaluation"
+area: ES|QL
+type: enhancement
+issues:
+ - 103097

+ 17 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java

@@ -7,10 +7,12 @@
 
 package org.elasticsearch.xpack.esql.optimizer;
 
+import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
 import org.elasticsearch.xpack.esql.plan.logical.Eval;
 import org.elasticsearch.xpack.esql.plan.logical.TopN;
 import org.elasticsearch.xpack.esql.stats.SearchStats;
 import org.elasticsearch.xpack.ql.expression.Alias;
+import org.elasticsearch.xpack.ql.expression.Expression;
 import org.elasticsearch.xpack.ql.expression.FieldAttribute;
 import org.elasticsearch.xpack.ql.expression.Literal;
 import org.elasticsearch.xpack.ql.expression.NamedExpression;
@@ -37,7 +39,13 @@ public class LocalLogicalPlanOptimizer extends ParameterizedRuleExecutor<Logical
 
     @Override
     protected List<Batch<LogicalPlan>> batches() {
-        var local = new Batch<>("Local rewrite", new ReplaceTopNWithLimitAndSort(), new ReplaceMissingFieldWithNull());
+        var local = new Batch<>(
+            "Local rewrite",
+            Limiter.ONCE,
+            new ReplaceTopNWithLimitAndSort(),
+            new ReplaceMissingFieldWithNull(),
+            new InferIsNotNull()
+        );
 
         var rules = new ArrayList<Batch<LogicalPlan>>();
         rules.add(local);
@@ -116,6 +124,14 @@ public class LocalLogicalPlanOptimizer extends ParameterizedRuleExecutor<Logical
         }
     }
 
+    static class InferIsNotNull extends OptimizerRules.InferIsNotNull {
+        // Coalesce is never unrolled into per-field IS NOT NULL checks (presumably because it can be non-null even if an argument is null)
+        @Override
+        protected boolean skipExpression(Expression e) {
+            return e instanceof Coalesce;
+        }
+    }
+
     abstract static class ParameterizedOptimizerRule<SubPlan extends LogicalPlan, P> extends ParameterizedRule<SubPlan, LogicalPlan, P> {
 
         public final LogicalPlan apply(LogicalPlan plan, P context) {

+ 40 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java

@@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.EsqlTestUtils;
 import org.elasticsearch.xpack.esql.analysis.Analyzer;
 import org.elasticsearch.xpack.esql.analysis.AnalyzerContext;
 import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
+import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
 import org.elasticsearch.xpack.esql.parser.EsqlParser;
 import org.elasticsearch.xpack.esql.plan.logical.Eval;
 import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
@@ -21,9 +22,11 @@ import org.elasticsearch.xpack.ql.expression.Alias;
 import org.elasticsearch.xpack.ql.expression.Expressions;
 import org.elasticsearch.xpack.ql.expression.Literal;
 import org.elasticsearch.xpack.ql.expression.ReferenceAttribute;
+import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNotNull;
 import org.elasticsearch.xpack.ql.index.EsIndex;
 import org.elasticsearch.xpack.ql.index.IndexResolution;
 import org.elasticsearch.xpack.ql.plan.logical.EsRelation;
+import org.elasticsearch.xpack.ql.plan.logical.Filter;
 import org.elasticsearch.xpack.ql.plan.logical.Limit;
 import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.ql.plan.logical.Project;
@@ -35,6 +38,7 @@ import java.util.List;
 import java.util.Map;
 
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.L;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_SEARCH_STATS;
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
@@ -263,6 +267,38 @@ public class LocalLogicalPlanOptimizerTests extends ESTestCase {
         );
     }
 
+    public void testIsNotNullOnCoalesce() {
+        var plan = localPlan("""
+              from test
+            | where coalesce(emp_no, salary) is not null
+            """);
+        // the predicate on COALESCE must survive untouched: Limit -> Filter(IsNotNull(Coalesce(emp_no, salary))) -> EsRelation
+        var limit = as(plan, Limit.class);
+        var filter = as(limit.child(), Filter.class);
+        var inn = as(filter.condition(), IsNotNull.class);
+        var coalesce = as(inn.children().get(0), Coalesce.class);
+        assertThat(Expressions.names(coalesce.children()), contains("emp_no", "salary"));
+        var source = as(filter.child(), EsRelation.class);
+    }
+
+    public void testIsNotNullOnExpression() {
+        var plan = localPlan("""
+              from test
+            | eval x = emp_no + 1
+            | where x is not null
+            """);
+        // expected shape: Limit -> Filter(x IS NOT NULL) -> Eval -> Filter(emp_no IS NOT NULL) -> EsRelation
+        var limit = as(plan, Limit.class);
+        var filter = as(limit.child(), Filter.class);
+        var inn = as(filter.condition(), IsNotNull.class);
+        assertThat(Expressions.names(inn.children()), contains("x"));
+        var eval = as(filter.child(), Eval.class);
+        filter = as(eval.child(), Filter.class);
+        inn = as(filter.condition(), IsNotNull.class);
+        assertThat(Expressions.names(inn.children()), contains("emp_no"));
+        var source = as(filter.child(), EsRelation.class);
+    }
+
     private LocalRelation asEmptyRelation(Object o) {
         var empty = as(o, LocalRelation.class);
         assertThat(empty.supplier(), is(LocalSupplier.EMPTY));
@@ -285,6 +321,10 @@ public class LocalLogicalPlanOptimizerTests extends ESTestCase {
         return localPlan;
     }
 
+    private LogicalPlan localPlan(String query) {
+        return localPlan(plan(query), TEST_SEARCH_STATS); // convenience overload: plan the query text, then locally optimize it with the shared test stats
+    }
+
     @Override
     protected List<String> filteredWarnings() {
         return withDefaultLimitWarning(super.filteredWarnings());

+ 90 - 0
x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java

@@ -8,6 +8,8 @@ package org.elasticsearch.xpack.ql.optimizer;
 
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.xpack.ql.expression.Alias;
+import org.elasticsearch.xpack.ql.expression.Attribute;
+import org.elasticsearch.xpack.ql.expression.AttributeMap;
 import org.elasticsearch.xpack.ql.expression.Expression;
 import org.elasticsearch.xpack.ql.expression.Expressions;
 import org.elasticsearch.xpack.ql.expression.Literal;
@@ -69,6 +71,7 @@ import java.util.function.BiFunction;
 
 import static java.lang.Math.signum;
 import static java.util.Arrays.asList;
+import static java.util.Collections.emptySet;
 import static org.elasticsearch.xpack.ql.expression.Literal.FALSE;
 import static org.elasticsearch.xpack.ql.expression.Literal.TRUE;
 import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.combineAnd;
@@ -1785,6 +1788,93 @@ public final class OptimizerRules {
         }
     }
 
+    /**
+     * Simplify IsNotNull targets by resolving the underlying expression to its root fields with unknown
+     * nullability; the original predicate is kept and the inferred ones are AND-ed in front of it.
+     * e.g.
+     * (x + 1) / 2 IS NOT NULL --> x IS NOT NULL AND (x+1) / 2 IS NOT NULL
+     * SUBSTRING(x, 3) > 4 IS NOT NULL --> x IS NOT NULL AND SUBSTRING(x, 3) > 4 IS NOT NULL
+     * When dealing with multiple fields, a conjunction over all of them is added:
+     * (x + y) / 4 IS NOT NULL --> x IS NOT NULL AND y IS NOT NULL AND (x + y) / 4 IS NOT NULL
+     * This handles the case of fields nested inside functions or expressions in order to avoid:
+     * - having to evaluate the whole expression
+     * - not pushing down the filter due to expression evaluation
+     * IS NULL cannot be simplified since it leads to a disjunction which prevents the filter from being
+     * pushed down:
+     * (x + 1) IS NULL --> x IS NULL OR x + 1 IS NULL
+     * and x IS NULL cannot be pushed down
+     * <br/>
+     * Implementation-wise this rule goes bottom-up, keeping an alias map up to date with the current plan,
+     * and then rewrites each IsNotNull it encounters.
+     */
+    public static class InferIsNotNull extends Rule<LogicalPlan, LogicalPlan> {
+
+        @Override
+        public LogicalPlan apply(LogicalPlan plan) {
+            // the alias map is shared across the whole plan
+            AttributeMap<Expression> aliases = new AttributeMap<>();
+            // traverse bottom-up to pick up the aliases as we go
+            plan = plan.transformUp(p -> inspectPlan(p, aliases));
+            return plan;
+        }
+
+        private LogicalPlan inspectPlan(LogicalPlan plan, AttributeMap<Expression> aliases) {
+            // record the aliases declared by this plan node (attribute -> aliased expression)
+            plan.forEachExpression(Alias.class, a -> aliases.put(a.toAttribute(), a.child()));
+            // now look for IsNotNull predicates in this node and expand them
+            LogicalPlan newPlan = plan.transformExpressionsOnlyUp(IsNotNull.class, inn -> inferNotNullable(inn, aliases));
+            return newPlan;
+        }
+
+        private Expression inferNotNullable(IsNotNull inn, AttributeMap<Expression> aliases) {
+            Expression result = inn;
+            Set<Expression> refs = resolveExpressionAsRootAttributes(inn.field(), aliases);
+            // no refs found or nothing could be inferred - return the original predicate
+            if (refs.size() > 0) {
+                // add an IsNotNull for each resolved reference, followed by the initial inn
+                var innList = CollectionUtils.combine(refs.stream().map(r -> (Expression) new IsNotNull(inn.source(), r)).toList(), inn);
+                result = Predicates.combineAnd(innList);
+            }
+            return result;
+        }
+
+        /**
+         * Unroll the expression to its references to get to the root fields
+         * that really matter for filtering; returns an empty set if nothing was unrolled.
+         */
+        protected Set<Expression> resolveExpressionAsRootAttributes(Expression exp, AttributeMap<Expression> aliases) {
+            Set<Expression> resolvedExpressions = new LinkedHashSet<>();
+            boolean changed = doResolve(exp, aliases, resolvedExpressions);
+            return changed ? resolvedExpressions : emptySet();
+        }
+
+        private boolean doResolve(Expression exp, AttributeMap<Expression> aliases, Set<Expression> resolvedExpressions) {
+            boolean changed = false;
+            // expressions to skip, or that are not nullable, are recorded as-is and not unrolled any further
+            if (skipExpression(exp) || exp.nullable() == Nullability.FALSE) {
+                resolvedExpressions.add(exp);
+            } else {
+                for (Expression e : exp.references()) {
+                    Expression resolved = aliases.resolve(e, e);
+                    // found a root attribute (nothing to resolve it to) - record it and stop descending
+                    if (resolved instanceof Attribute a && resolved == e) {
+                        resolvedExpressions.add(a);
+                        // don't mark things as changed if the original expression hasn't been broken down
+                        changed |= resolved != exp;
+                    } else {
+                        // resolved to something else (e.g. an aliased expression) - keep unrolling
+                        changed |= doResolve(resolved, aliases, resolvedExpressions);
+                    }
+                }
+            }
+            return changed;
+        }
+        // extension hook: subclasses return true for expressions that must not be unrolled (by default nothing is skipped)
+        protected boolean skipExpression(Expression e) {
+            return false;
+        }
+    }
+
     public static final class SetAsOptimized extends Rule<LogicalPlan, LogicalPlan> {
 
         @Override

+ 84 - 0
x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRulesTests.java

@@ -14,6 +14,7 @@ import org.elasticsearch.xpack.ql.expression.FieldAttribute;
 import org.elasticsearch.xpack.ql.expression.Literal;
 import org.elasticsearch.xpack.ql.expression.Nullability;
 import org.elasticsearch.xpack.ql.expression.function.aggregate.Count;
+import org.elasticsearch.xpack.ql.expression.function.scalar.string.StartsWith;
 import org.elasticsearch.xpack.ql.expression.predicate.BinaryOperator;
 import org.elasticsearch.xpack.ql.expression.predicate.Predicates;
 import org.elasticsearch.xpack.ql.expression.predicate.Range;
@@ -1768,7 +1769,90 @@ public class OptimizerRulesTests extends ESTestCase {
         // expected
         Filter expected = new Filter(EMPTY, new Aggregate(EMPTY, combinedFilter, emptyList(), emptyList()), aggregateCondition);
         assertEquals(expected, new PushDownAndCombineFilters().apply(fb));
+    }
+
+    public void testIsNotNullOnIsNullField() {
+        EsRelation relation = relation();
+        var fieldA = getFieldAttribute("a");
+        Expression inn = isNotNull(fieldA);
+        Filter f = new Filter(EMPTY, relation, inn);
+        // IS NOT NULL directly on a field has nothing to break down, so the filter must come back unchanged
+        assertEquals(f, new OptimizerRules.InferIsNotNull().apply(f));
+    }
+
+    public void testIsNotNullOnOperatorWithOneField() {
+        EsRelation relation = relation();
+        var fieldA = getFieldAttribute("a");
+        Expression inn = isNotNull(new Add(EMPTY, fieldA, ONE));
+        Filter f = new Filter(EMPTY, relation, inn);
+        Filter expected = new Filter(EMPTY, relation, new And(EMPTY, isNotNull(fieldA), inn));
+        // the inferred check on field a comes first and the original predicate is kept last
+        assertEquals(expected, new OptimizerRules.InferIsNotNull().apply(f));
+    }
+
+    public void testIsNotNullOnOperatorWithTwoFields() {
+        EsRelation relation = relation();
+        var fieldA = getFieldAttribute("a");
+        var fieldB = getFieldAttribute("b");
+        Expression inn = isNotNull(new Add(EMPTY, fieldA, fieldB));
+        Filter f = new Filter(EMPTY, relation, inn);
+        Filter expected = new Filter(EMPTY, relation, new And(EMPTY, new And(EMPTY, isNotNull(fieldA), isNotNull(fieldB)), inn));
+        // both operands get their own check, in order: (a IS NOT NULL AND b IS NOT NULL) AND the original predicate
+        assertEquals(expected, new OptimizerRules.InferIsNotNull().apply(f));
+    }
+
+    public void testIsNotNullOnFunctionWithOneField() {
+        EsRelation relation = relation();
+        var fieldA = getFieldAttribute("a");
+        var pattern = L("abc");
+        Expression inn = isNotNull(
+            new And(EMPTY, new TestStartsWith(EMPTY, fieldA, pattern, false), greaterThanOf(new Add(EMPTY, ONE, TWO), THREE))
+        );
+        // only field a is referenced by the expression, so a single IS NOT NULL is expected in front of the original predicate
+        Filter f = new Filter(EMPTY, relation, inn);
+        Filter expected = new Filter(EMPTY, relation, new And(EMPTY, isNotNull(fieldA), inn));
+
+        assertEquals(expected, new OptimizerRules.InferIsNotNull().apply(f));
+    }
+
+    public void testIsNotNullOnFunctionWithTwoFields() {
+        EsRelation relation = relation();
+        var fieldA = getFieldAttribute("a");
+        var fieldB = getFieldAttribute("b");
+        var pattern = L("abc");
+        Expression inn = isNotNull(new TestStartsWith(EMPTY, fieldA, fieldB, false));
+        // NOTE(review): the pattern local above is unused in this test - both function arguments are fields
+        Filter f = new Filter(EMPTY, relation, inn);
+        Filter expected = new Filter(EMPTY, relation, new And(EMPTY, new And(EMPTY, isNotNull(fieldA), isNotNull(fieldB)), inn));
+
+        assertEquals(expected, new OptimizerRules.InferIsNotNull().apply(f));
+    }
+
+    public static class TestStartsWith extends StartsWith {
+        // test-only StartsWith subclass that supplies replaceChildren/info (presumably not provided by the base class - confirm)
+        public TestStartsWith(Source source, Expression input, Expression pattern, boolean caseInsensitive) {
+            super(source, input, pattern, caseInsensitive);
+        }
+
+        @Override
+        public Expression replaceChildren(List<Expression> newChildren) {
+            return new TestStartsWith(source(), newChildren.get(0), newChildren.get(1), isCaseInsensitive());
+        }
+
+        @Override
+        protected NodeInfo<TestStartsWith> info() {
+            return NodeInfo.create(this, TestStartsWith::new, input(), pattern(), isCaseInsensitive());
+        }
+    }
+
+    public void testIsNotNullOnFunctionWithTwoField() {} // NOTE(review): empty placeholder with no assertions; name nearly duplicates testIsNotNullOnFunctionWithTwoFields - fill in or remove
+
+    private IsNotNull isNotNull(Expression field) {
+        return new IsNotNull(EMPTY, field); // test shorthand: wraps the expression in an IsNotNull built with the EMPTY source
+    }
 
+    private IsNull isNull(Expression field) {
+        return new IsNull(EMPTY, field); // NOTE(review): not referenced by any test visible in this diff - confirm it is used elsewhere
     }
 
     private Literal nullOf(DataType dataType) {