Browse Source

[8.19] Make OptimizerExpressionRule conditional (#127753)

* Make OptimizerExpressionRule conditional (#127500)

(cherry picked from commit 7d466c9d59bbaabf91c0ede70faef6ccf17a9c2d)

* replace pattern with instanceof

* fix flakiness
Ievgen Degtiarenko 5 months ago
parent
commit
5c7bc7a29b

+ 10 - 3
x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java

@@ -184,16 +184,19 @@ public abstract class Node<T extends Node<T>> implements NamedWriteable {
     public T transformDown(Function<? super T, ? extends T> rule) {
         T root = rule.apply((T) this);
         Node<T> node = this.equals(root) ? this : root;
-
         return node.transformChildren(child -> child.transformDown(rule));
     }
 
     @SuppressWarnings("unchecked")
     public <E extends T> T transformDown(Class<E> typeToken, Function<E, ? extends T> rule) {
-        // type filtering function
         return transformDown((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t));
     }
 
+    @SuppressWarnings("unchecked")
+    public <E extends T> T transformDown(Predicate<Node<?>> nodePredicate, Function<E, ? extends T> rule) {
+        return transformDown((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t));
+    }
+
     @SuppressWarnings("unchecked")
     public T transformUp(Function<? super T, ? extends T> rule) {
         T transformed = transformChildren(child -> child.transformUp(rule));
@@ -203,10 +206,14 @@ public abstract class Node<T extends Node<T>> implements NamedWriteable {
 
     @SuppressWarnings("unchecked")
     public <E extends T> T transformUp(Class<E> typeToken, Function<E, ? extends T> rule) {
-        // type filtering function
         return transformUp((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t));
     }
 
+    @SuppressWarnings("unchecked")
+    public <E extends T> T transformUp(Predicate<Node<?>> nodePredicate, Function<E, ? extends T> rule) {
+        return transformUp((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t));
+    }
+
     @SuppressWarnings("unchecked")
     protected <R extends Function<? super T, ? extends T>> T transformChildren(Function<T, ? extends T> traversalOperation) {
         boolean childrenChanged = false;

+ 15 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java

@@ -7,9 +7,13 @@
 package org.elasticsearch.xpack.esql.optimizer.rules.logical;
 
 import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Node;
 import org.elasticsearch.xpack.esql.core.util.ReflectionUtils;
 import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
+import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
+import org.elasticsearch.xpack.esql.plan.logical.Limit;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+import org.elasticsearch.xpack.esql.plan.logical.Project;
 import org.elasticsearch.xpack.esql.rule.ParameterizedRule;
 import org.elasticsearch.xpack.esql.rule.Rule;
 
@@ -55,12 +59,21 @@ public final class OptimizerRules {
         @Override
         public final LogicalPlan apply(LogicalPlan plan, LogicalOptimizerContext ctx) {
             return direction == TransformDirection.DOWN
-                ? plan.transformExpressionsDown(expressionTypeToken, e -> rule(e, ctx))
-                : plan.transformExpressionsUp(expressionTypeToken, e -> rule(e, ctx));
+                ? plan.transformExpressionsDown(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx))
+                : plan.transformExpressionsUp(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx));
         }
 
         protected abstract Expression rule(E e, LogicalOptimizerContext ctx);
 
+        /**
+         * Defines if a node should be visited or not.
+         * Allows to skip nodes that are not applicable for the rule even if they contain expressions.
+         * By default that skips FROM, LIMIT, PROJECT, KEEP and DROP but this list could be extended or replaced in subclasses.
+         */
+        protected boolean shouldVisit(Node<?> node) {
+            return (node instanceof EsRelation || node instanceof Project || node instanceof Limit) == false;
+        }
+
         public Class<E> expressionToken() {
             return expressionTypeToken;
         }

+ 21 - 14
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java

@@ -18,6 +18,7 @@ import java.util.Collection;
 import java.util.List;
 import java.util.function.Consumer;
 import java.util.function.Function;
+import java.util.function.Predicate;
 
 /**
  * There are two main types of plans, {@code LogicalPlan} and {@code PhysicalPlan}
@@ -113,22 +114,36 @@ public abstract class QueryPlan<PlanType extends QueryPlan<PlanType>> extends No
         return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
     }
 
-    public PlanType transformExpressionsDown(Function<Expression, ? extends Expression> rule) {
-        return transformExpressionsDown(Expression.class, rule);
-    }
-
     public <E extends Expression> PlanType transformExpressionsDown(Class<E> typeToken, Function<E, ? extends Expression> rule) {
         return transformPropertiesDown(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)));
     }
 
-    public PlanType transformExpressionsUp(Function<Expression, ? extends Expression> rule) {
-        return transformExpressionsUp(Expression.class, rule);
+    public <E extends Expression> PlanType transformExpressionsDown(
+        Predicate<Node<?>> shouldVisit,
+        Class<E> typeToken,
+        Function<E, ? extends Expression> rule
+    ) {
+        return transformDown(
+            shouldVisit,
+            t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)))
+        );
     }
 
     public <E extends Expression> PlanType transformExpressionsUp(Class<E> typeToken, Function<E, ? extends Expression> rule) {
         return transformPropertiesUp(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
     }
 
+    public <E extends Expression> PlanType transformExpressionsUp(
+        Predicate<Node<?>> shouldVisit,
+        Class<E> typeToken,
+        Function<E, ? extends Expression> rule
+    ) {
+        return transformUp(
+            shouldVisit,
+            t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)))
+        );
+    }
+
     @SuppressWarnings("unchecked")
     private static Object doTransformExpression(Object arg, Function<Expression, ? extends Expression> traversal) {
         if (arg instanceof Expression exp) {
@@ -188,18 +203,10 @@ public abstract class QueryPlan<PlanType extends QueryPlan<PlanType>> extends No
         forEachPropertyOnly(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
     }
 
-    public void forEachExpressionDown(Consumer<? super Expression> rule) {
-        forEachExpressionDown(Expression.class, rule);
-    }
-
     public <E extends Expression> void forEachExpressionDown(Class<? extends E> typeToken, Consumer<? super E> rule) {
         forEachPropertyDown(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
     }
 
-    public void forEachExpressionUp(Consumer<? super Expression> rule) {
-        forEachExpressionUp(Expression.class, rule);
-    }
-
     public <E extends Expression> void forEachExpressionUp(Class<E> typeToken, Consumer<? super E> rule) {
         forEachPropertyUp(Object.class, e -> doForEachExpression(e, exp -> exp.forEachUp(typeToken, rule)));
     }

+ 37 - 16
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java

@@ -8,31 +8,38 @@ package org.elasticsearch.xpack.esql.optimizer;
 
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.esql.core.expression.Alias;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
 import org.elasticsearch.xpack.esql.core.expression.FoldContext;
 import org.elasticsearch.xpack.esql.core.expression.Literal;
 import org.elasticsearch.xpack.esql.core.expression.Nullability;
+import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.type.DataType;
-import org.elasticsearch.xpack.esql.core.util.TestUtils;
 import org.elasticsearch.xpack.esql.expression.predicate.Range;
+import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
+import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
+import org.elasticsearch.xpack.esql.parser.EsqlParser;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf;
 import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
+import static org.elasticsearch.xpack.esql.core.util.TestUtils.getFieldAttribute;
 import static org.elasticsearch.xpack.esql.core.util.TestUtils.of;
+import static org.hamcrest.Matchers.containsInAnyOrder;
 
 public class OptimizerRulesTests extends ESTestCase {
 
-    private static final Literal FIVE = L(5);
-    private static final Literal SIX = L(6);
+    private static final Literal FIVE = of(5);
+    private static final Literal SIX = of(6);
 
-    public static class DummyBooleanExpression extends Expression {
+    public static final class DummyBooleanExpression extends Expression {
 
         private final int id;
 
@@ -87,21 +94,13 @@ public class OptimizerRulesTests extends ESTestCase {
         }
     }
 
-    private static Literal L(Object value) {
-        return of(value);
-    }
-
-    private static FieldAttribute getFieldAttribute() {
-        return TestUtils.getFieldAttribute("a");
-    }
-
     //
     // Range optimization
     //
 
     // 6 < a <= 5 -> FALSE
     public void testFoldExcludingRangeToFalse() {
-        FieldAttribute fa = getFieldAttribute();
+        FieldAttribute fa = getFieldAttribute("a");
 
         Range r = rangeOf(fa, SIX, false, FIVE, true);
         assertTrue(r.foldable());
@@ -110,13 +109,35 @@ public class OptimizerRulesTests extends ESTestCase {
 
     // 6 < a <= 5.5 -> FALSE
     public void testFoldExcludingRangeWithDifferentTypesToFalse() {
-        FieldAttribute fa = getFieldAttribute();
+        FieldAttribute fa = getFieldAttribute("a");
 
-        Range r = rangeOf(fa, SIX, false, L(5.5d), true);
+        Range r = rangeOf(fa, SIX, false, of(5.5d), true);
         assertTrue(r.foldable());
         assertEquals(Boolean.FALSE, r.fold(FoldContext.small()));
     }
 
-    // Conjunction
+    public void testOptimizerExpressionRuleShouldNotVisitExcludedNodes() {
+        var rule = new OptimizerRules.OptimizerExpressionRule<>(randomFrom(OptimizerRules.TransformDirection.values())) {
+            private final List<Expression> appliedTo = new ArrayList<>();
 
+            @Override
+            protected Expression rule(Expression e, LogicalOptimizerContext ctx) {
+                appliedTo.add(e);
+                return e;
+            }
+        };
+
+        rule.apply(
+            new EsqlParser().createStatement("FROM index | EVAL x=f1+1 | KEEP x, f2 | LIMIT 1"),
+            new LogicalOptimizerContext(null, FoldContext.small())
+        );
+
+        var literal = new Literal(new Source(1, 25, "1"), 1, DataType.INTEGER);
+        var attribute = new UnresolvedAttribute(new Source(1, 20, "f1"), "f1");
+        var add = new Add(new Source(1, 20, "f1+1"), attribute, literal);
+        var alias = new Alias(new Source(1, 18, "x=f1+1"), "x", add);
+
+        // contains expressions only from EVAL
+        assertThat(rule.appliedTo, containsInAnyOrder(alias, add, attribute, literal));
+    }
 }