浏览代码

SQL: Extend the optimisations for equalities (#50792)

* Extend the optimizations for equalities

This commit supplements the optimisations of equalities in conjunctions
and disjunctions:
* for conjunctions, the existing optimizations with ranges are extended
with not-equalities and inequalities; these lead to a fast resolution,
the conjunction either being evaluate to a FALSE, or the non-equality
conditions being dropped as superfluous;
* optimisations for disjunctions are added to be applied against ranges,
inequalities and not-equalities; these lead to disjunction either
becoming TRUE or the equality being dropped, either as superfluous or
merged into a range/inequality.

* Adress review notes

* Fix the bug around wrongly optimizing 'a=2 OR a!=?', which only yields
TRUE for same values in equality and inequality.
* Var renamings, code style adjustments, comments corrections.

* Address further review comments. Extend optim.

- fix a few code comments;
- extend the Equals OR NotEquals optimitsation (a=2 OR a!=5 -> a!=5);
- extend the Equals OR Range optimisation on limits equality (a=2 OR
  2<=a<5 -> 2<=a<5);
- in case an equality is being removed in a conjunction, the rest of
  possible optimisations to test is now skipped.

* rename one var for better legiblity

- s/rmEqual/removeEquals
Bogdan Pintea 5 年之前
父节点
当前提交
62e7c6a010

+ 221 - 22
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java

@@ -84,6 +84,7 @@ import org.elasticsearch.xpack.sql.session.SingletonExecutable;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.LinkedList;
@@ -814,7 +815,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
                 }
 
                 //
-                // common factor extraction -> (a || b) && (a || c) => a && (b || c)
+                // common factor extraction -> (a || b) && (a || c) => a || (b && c)
                 //
                 List<Expression> leftSplit = splitOr(l);
                 List<Expression> rightSplit = splitOr(r);
@@ -852,7 +853,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
                 }
 
                 //
-                // common factor extraction -> (a && b) || (a && c) => a || (b & c)
+                // common factor extraction -> (a && b) || (a && c) => a && (b || c)
                 //
                 List<Expression> leftSplit = splitAnd(l);
                 List<Expression> rightSplit = splitAnd(r);
@@ -958,9 +959,11 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
     }
 
     /**
-     * Propagate Equals to eliminate conjuncted Ranges.
-     * When encountering a different Equals or non-containing {@link Range}, the conjunction becomes false.
-     * When encountering a containing {@link Range}, the range gets eliminated by the equality.
+     * Propagate Equals to eliminate conjuncted Ranges or BinaryComparisons.
+     * When encountering a different Equals, non-containing {@link Range} or {@link BinaryComparison}, the conjunction becomes false.
+     * When encountering a containing {@link Range}, {@link BinaryComparison} or {@link NotEquals}, these get eliminated by the equality.
+     *
+     * Since this rule can eliminate Ranges and BinaryComparisons, it should be applied before {@link CombineBinaryComparisons}.
      *
      * This rule doesn't perform any promotion of {@link BinaryComparison}s, that is handled by
      * {@link CombineBinaryComparisons} on purpose as the resulting Range might be foldable
@@ -976,6 +979,8 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
         protected Expression rule(Expression e) {
             if (e instanceof And) {
                 return propagate((And) e);
+            } else if (e instanceof Or) {
+                return propagate((Or) e);
             }
             return e;
         }
@@ -983,7 +988,11 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
         // combine conjunction
         private Expression propagate(And and) {
             List<Range> ranges = new ArrayList<>();
+            // Only equalities, not-equalities and inequalities with a foldable .right are extracted separately;
+            // the others go into the general 'exps'.
             List<BinaryComparison> equals = new ArrayList<>();
+            List<NotEquals> notEquals = new ArrayList<>();
+            List<BinaryComparison> inequalities = new ArrayList<>();
             List<Expression> exps = new ArrayList<>();
 
             boolean changed = false;
@@ -996,24 +1005,35 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
                     // equals on different values evaluate to FALSE
                     if (otherEq.right().foldable()) {
                         for (BinaryComparison eq : equals) {
-                            // cannot evaluate equals so skip it
-                            if (!eq.right().foldable()) {
-                                continue;
-                            }
                             if (otherEq.left().semanticEquals(eq.left())) {
-                                if (eq.right().foldable() && otherEq.right().foldable()) {
-                                    Integer comp = BinaryComparison.compare(eq.right().fold(), otherEq.right().fold());
-                                    if (comp != null) {
-                                        // var cannot be equal to two different values at the same time
-                                        if (comp != 0) {
-                                            return FALSE;
-                                        }
+                                Integer comp = BinaryComparison.compare(eq.right().fold(), otherEq.right().fold());
+                                if (comp != null) {
+                                    // var cannot be equal to two different values at the same time
+                                    if (comp != 0) {
+                                        return FALSE;
                                     }
                                 }
                             }
                         }
+                        equals.add(otherEq);
+                    } else {
+                        exps.add(otherEq);
+                    }
+                } else if (ex instanceof GreaterThan || ex instanceof GreaterThanOrEqual ||
+                    ex instanceof LessThan || ex instanceof LessThanOrEqual) {
+                    BinaryComparison bc = (BinaryComparison) ex;
+                    if (bc.right().foldable()) {
+                        inequalities.add(bc);
+                    } else {
+                        exps.add(ex);
+                    }
+                } else if (ex instanceof NotEquals) {
+                    NotEquals otherNotEq = (NotEquals) ex;
+                    if (otherNotEq.right().foldable()) {
+                        notEquals.add(otherNotEq);
+                    } else {
+                        exps.add(ex);
                     }
-                    equals.add(otherEq);
                 } else {
                     exps.add(ex);
                 }
@@ -1021,10 +1041,6 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
 
             // check
             for (BinaryComparison eq : equals) {
-                // cannot evaluate equals so skip it
-                if (!eq.right().foldable()) {
-                    continue;
-                }
                 Object eqValue = eq.right().fold();
 
                 for (int i = 0; i < ranges.size(); i++) {
@@ -1060,9 +1076,192 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
                         changed = true;
                     }
                 }
+
+                // evaluate all NotEquals against the Equal
+                for (Iterator<NotEquals> iter = notEquals.iterator(); iter.hasNext(); ) {
+                    NotEquals neq = iter.next();
+                    if (eq.left().semanticEquals(neq.left())) {
+                        Integer comp = BinaryComparison.compare(eqValue, neq.right().fold());
+                        if (comp != null) {
+                            if (comp == 0) {
+                                return FALSE; // clashing and conflicting: a = 1 AND a != 1
+                            } else {
+                                iter.remove(); // clashing and redundant: a = 1 AND a != 2
+                                changed = true;
+                            }
+                        }
+                    }
+                }
+
+                // evaluate all inequalities against the Equal
+                for (Iterator<BinaryComparison> iter = inequalities.iterator(); iter.hasNext(); ) {
+                    BinaryComparison bc = iter.next();
+                    if (eq.left().semanticEquals(bc.left())) {
+                        Integer compare = BinaryComparison.compare(eqValue, bc.right().fold());
+                        if (compare != null) {
+                            if (bc instanceof LessThan || bc instanceof LessThanOrEqual) { // a = 2 AND a </<= ?
+                                if ((compare == 0 && bc instanceof LessThan) || // a = 2 AND a < 2
+                                    0 < compare) { // a = 2 AND a </<= 1
+                                    return FALSE;
+                                }
+                            } else if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) { // a = 2 AND a >/>= ?
+                                if ((compare == 0 && bc instanceof GreaterThan) || // a = 2 AND a > 2
+                                    compare < 0) { // a = 2 AND a >/>= 3
+                                    return FALSE;
+                                }
+                            }
+
+                            iter.remove();
+                            changed = true;
+                        }
+                    }
+                }
+            }
+
+            return changed ? Predicates.combineAnd(CollectionUtils.combine(exps, equals, notEquals, inequalities, ranges)) : and;
+        }
+
+        // combine disjunction:
+        // a = 2 OR a > 3 -> nop; a = 2 OR a > 1 -> a > 1
+        // a = 2 OR a < 3 -> a < 3; a = 2 OR a < 1 -> nop
+        // a = 2 OR 3 < a < 5 -> nop; a = 2 OR 1 < a < 3 -> 1 < a < 3; a = 2 OR 0 < a < 1 -> nop
+        // a = 2 OR a != 2 -> TRUE; a = 2 OR a = 5 -> nop; a = 2 OR a != 5 -> a != 5
+        private Expression propagate(Or or) {
+            List<Expression> exps = new ArrayList<>();
+            List<Equals> equals = new ArrayList<>(); // foldable right term Equals
+            List<NotEquals> notEquals = new ArrayList<>(); // foldable right term NotEquals
+            List<Range> ranges = new ArrayList<>();
+            List<BinaryComparison> inequalities = new ArrayList<>(); // foldable right term (=limit) BinaryComparision
+
+            // split expressions by type
+            for (Expression ex : Predicates.splitOr(or)) {
+                if (ex instanceof Equals) {
+                    Equals eq = (Equals) ex;
+                    if (eq.right().foldable()) {
+                        equals.add(eq);
+                    } else {
+                        exps.add(ex);
+                    }
+                } else if (ex instanceof NotEquals) {
+                    NotEquals neq = (NotEquals) ex;
+                    if (neq.right().foldable()) {
+                        notEquals.add(neq);
+                    } else {
+                        exps.add(ex);
+                    }
+                } else if (ex instanceof Range) {
+                    ranges.add((Range) ex);
+                } else if (ex instanceof BinaryComparison) {
+                    BinaryComparison bc = (BinaryComparison) ex;
+                    if (bc.right().foldable()) {
+                        inequalities.add(bc);
+                    } else {
+                        exps.add(ex);
+                    }
+                } else {
+                    exps.add(ex);
+                }
+            }
+
+            boolean updated = false; // has the expression been modified?
+
+            // evaluate the impact of each Equal over the different types of Expressions
+            for (Iterator<Equals> iterEq = equals.iterator(); iterEq.hasNext(); ) {
+                Equals eq = iterEq.next();
+                Object eqValue = eq.right().fold();
+                boolean removeEquals = false;
+
+                // Equals OR NotEquals
+                for (NotEquals neq : notEquals) {
+                    if (eq.left().semanticEquals(neq.left())) { // a = 2 OR a != ? -> ...
+                        Integer comp = BinaryComparison.compare(eqValue, neq.right().fold());
+                        if (comp != null) {
+                            if (comp == 0) { // a = 2 OR a != 2 -> TRUE
+                                return TRUE;
+                            } else { // a = 2 OR a != 5 -> a != 5
+                                removeEquals = true;
+                                break;
+                            }
+                        }
+                    }
+                }
+                if (removeEquals) {
+                    iterEq.remove();
+                    updated = true;
+                    continue;
+                }
+
+                // Equals OR Range
+                for (int i = 0; i < ranges.size(); i ++) { // might modify list, so use index loop
+                    Range range = ranges.get(i);
+                    if (eq.left().semanticEquals(range.value())) {
+                        Integer lowerComp = range.lower().foldable() ? BinaryComparison.compare(eqValue, range.lower().fold()) : null;
+                        Integer upperComp = range.upper().foldable() ? BinaryComparison.compare(eqValue, range.upper().fold()) : null;
+
+                        if (lowerComp != null && lowerComp == 0) {
+                            if (!range.includeLower()) { // a = 2 OR 2 < a < ? -> 2 <= a < ?
+                                ranges.set(i, new Range(range.source(), range.value(), range.lower(), true,
+                                    range.upper(), range.includeUpper()));
+                            } // else : a = 2 OR 2 <= a < ? -> 2 <= a < ?
+                            removeEquals = true; // update range with lower equality instead or simply superfluous
+                            break;
+                        } else if (upperComp != null && upperComp == 0) {
+                            if (!range.includeUpper()) { // a = 2 OR ? < a < 2 -> ? < a <= 2
+                                ranges.set(i, new Range(range.source(), range.value(), range.lower(), range.includeLower(),
+                                    range.upper(), true));
+                            } // else : a = 2 OR ? < a <= 2 -> ? < a <= 2
+                            removeEquals = true; // update range with upper equality instead
+                            break;
+                        } else if (lowerComp != null && upperComp != null) {
+                            if (0 < lowerComp && upperComp < 0) { // a = 2 OR 1 < a < 3
+                                removeEquals = true; // equality is superfluous
+                                break;
+                            }
+                        }
+                    }
+                }
+                if (removeEquals) {
+                    iterEq.remove();
+                    updated = true;
+                    continue;
+                }
+
+                // Equals OR Inequality
+                for (int i = 0; i < inequalities.size(); i ++) {
+                    BinaryComparison bc = inequalities.get(i);
+                    if (eq.left().semanticEquals(bc.left())) {
+                        Integer comp = BinaryComparison.compare(eqValue, bc.right().fold());
+                        if (comp != null) {
+                            if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) {
+                                if (comp < 0) { // a = 1 OR a > 2 -> nop
+                                    continue;
+                                } else if (comp == 0 && bc instanceof GreaterThan) { // a = 2 OR a > 2 -> a >= 2
+                                    inequalities.set(i, new GreaterThanOrEqual(bc.source(), bc.left(), bc.right()));
+                                } // else (0 < comp || bc instanceof GreaterThanOrEqual) :
+                                // a = 3 OR a > 2 -> a > 2; a = 2 OR a => 2 -> a => 2
+
+                                removeEquals = true; // update range with equality instead or simply superfluous
+                                break;
+                            } else if (bc instanceof LessThan || bc instanceof LessThanOrEqual) {
+                                if (comp > 0) { // a = 2 OR a < 1 -> nop
+                                    continue;
+                                }
+                                if (comp == 0 && bc instanceof LessThan) { // a = 2 OR a < 2 -> a <= 2
+                                    inequalities.set(i, new LessThanOrEqual(bc.source(), bc.left(), bc.right()));
+                                } // else (comp < 0 || bc instanceof LessThanOrEqual) : a = 2 OR a < 3 -> a < 3; a = 2 OR a <= 2 -> a <= 2
+                                removeEquals = true; // update range with equality instead or simply superfluous
+                                break;
+                            }
+                        }
+                    }
+                }
+                if (removeEquals) {
+                    iterEq.remove();
+                    updated = true;
+                }
             }
 
-            return changed ? Predicates.combineAnd(CollectionUtils.combine(exps, equals, ranges)) : and;
+            return updated ? Predicates.combineOr(CollectionUtils.combine(exps, equals, notEquals, inequalities, ranges)) : or;
         }
     }
 

+ 246 - 0
x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java

@@ -22,6 +22,7 @@ import org.elasticsearch.xpack.ql.expression.function.Function;
 import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
 import org.elasticsearch.xpack.ql.expression.function.aggregate.InnerAggregate;
 import org.elasticsearch.xpack.ql.expression.predicate.BinaryOperator;
+import org.elasticsearch.xpack.ql.expression.predicate.Predicates;
 import org.elasticsearch.xpack.ql.expression.predicate.Range;
 import org.elasticsearch.xpack.ql.expression.predicate.conditional.ArbitraryConditionalFunction;
 import org.elasticsearch.xpack.ql.expression.predicate.conditional.Case;
@@ -1550,6 +1551,251 @@ public class OptimizerTests extends ESTestCase {
         assertEquals(FALSE, rule.rule(exp));
     }
 
+    // a != 3 AND a = 3 -> FALSE
+    public void testPropagateEquals_VarNeq3AndVarEq3() {
+        FieldAttribute fa = getFieldAttribute();
+        NotEquals neq = new NotEquals(EMPTY, fa, THREE);
+        Equals eq = new Equals(EMPTY, fa, THREE);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new And(EMPTY, neq, eq));
+        assertEquals(FALSE, rule.rule(exp));
+    }
+
+    // a != 4 AND a = 3 -> a = 3
+    public void testPropagateEquals_VarNeq4AndVarEq3() {
+        FieldAttribute fa = getFieldAttribute();
+        NotEquals neq = new NotEquals(EMPTY, fa, FOUR);
+        Equals eq = new Equals(EMPTY, fa, THREE);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new And(EMPTY, neq, eq));
+        assertEquals(Equals.class, exp.getClass());
+        assertEquals(eq, rule.rule(exp));
+    }
+
+    // a = 2 AND a < 2 -> FALSE
+    public void testPropagateEquals_VarEq2AndVarLt2() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        LessThan lt = new LessThan(EMPTY, fa, TWO);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new And(EMPTY, eq, lt));
+        assertEquals(FALSE, exp);
+    }
+
+    // a = 2 AND a <= 2 -> a = 2
+    public void testPropagateEquals_VarEq2AndVarLte2() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        LessThanOrEqual lt = new LessThanOrEqual(EMPTY, fa, TWO);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new And(EMPTY, eq, lt));
+        assertEquals(eq, exp);
+    }
+
+    // a = 2 AND a <= 1 -> FALSE
+    public void testPropagateEquals_VarEq2AndVarLte1() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        LessThanOrEqual lt = new LessThanOrEqual(EMPTY, fa, ONE);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new And(EMPTY, eq, lt));
+        assertEquals(FALSE, exp);
+    }
+
+    // a = 2 AND a > 2 -> FALSE
+    public void testPropagateEquals_VarEq2AndVarGt2() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        GreaterThan gt = new GreaterThan(EMPTY, fa, TWO);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new And(EMPTY, eq, gt));
+        assertEquals(FALSE, exp);
+    }
+
+    // a = 2 AND a >= 2 -> a = 2
+    public void testPropagateEquals_VarEq2AndVarGte2() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        GreaterThanOrEqual gte = new GreaterThanOrEqual(EMPTY, fa, TWO);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new And(EMPTY, eq, gte));
+        assertEquals(eq, exp);
+    }
+
+    // a = 2 AND a > 3 -> FALSE
+    public void testPropagateEquals_VarEq2AndVarLt3() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        GreaterThan gt = new GreaterThan(EMPTY, fa, THREE);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new And(EMPTY, eq, gt));
+        assertEquals(FALSE, exp);
+    }
+
+    // a = 2 AND a < 3 AND a > 1 AND a != 4 -> a = 2
+    public void testPropagateEquals_VarEq2AndVarLt3AndVarGt1AndVarNeq4() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        LessThan lt = new LessThan(EMPTY, fa, THREE);
+        GreaterThan gt = new GreaterThan(EMPTY, fa, ONE);
+        NotEquals neq = new NotEquals(EMPTY, fa, FOUR);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression and = Predicates.combineAnd(Arrays.asList(eq, lt, gt, neq));
+        Expression exp = rule.rule(and);
+        assertEquals(eq, exp);
+    }
+
+    // a = 2 AND 1 < a < 3 AND a > 0 AND a != 4 -> a = 2
+    public void testPropagateEquals_VarEq2AndVarRangeGt1Lt3AndVarGt0AndVarNeq4() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        Range range = new Range(EMPTY, fa, ONE, false, THREE, false);
+        GreaterThan gt = new GreaterThan(EMPTY, fa, L(0));
+        NotEquals neq = new NotEquals(EMPTY, fa, FOUR);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression and = Predicates.combineAnd(Arrays.asList(eq, range, gt, neq));
+        Expression exp = rule.rule(and);
+        assertEquals(eq, exp);
+    }
+
+    // a = 2 OR a > 1 -> a > 1
+    public void testPropagateEquals_VarEq2OrVarGt1() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        GreaterThan gt = new GreaterThan(EMPTY, fa, ONE);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new Or(EMPTY, eq, gt));
+        assertEquals(gt, exp);
+    }
+
+    // a = 2 OR a > 2 -> a >= 2
+    public void testPropagateEquals_VarEq2OrVarGte2() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        GreaterThan gt = new GreaterThan(EMPTY, fa, TWO);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new Or(EMPTY, eq, gt));
+        assertEquals(GreaterThanOrEqual.class, exp.getClass());
+        GreaterThanOrEqual gte = (GreaterThanOrEqual) exp;
+        assertEquals(TWO, gte.right());
+    }
+
+    // a = 2 OR a < 3 -> a < 3
+    public void testPropagateEquals_VarEq2OrVarLt3() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        LessThan lt = new LessThan(EMPTY, fa, THREE);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new Or(EMPTY, eq, lt));
+        assertEquals(lt, exp);
+    }
+
+    // a = 3 OR a < 3 -> a <= 3
+    public void testPropagateEquals_VarEq3OrVarLt3() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, THREE);
+        LessThan lt = new LessThan(EMPTY, fa, THREE);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new Or(EMPTY, eq, lt));
+        assertEquals(LessThanOrEqual.class, exp.getClass());
+        LessThanOrEqual lte = (LessThanOrEqual) exp;
+        assertEquals(THREE, lte.right());
+    }
+
+    // a = 2 OR 1 < a < 3 -> 1 < a < 3
+    public void testPropagateEquals_VarEq2OrVarRangeGt1Lt3() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        Range range = new Range(EMPTY, fa, ONE, false, THREE, false);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new Or(EMPTY, eq, range));
+        assertEquals(range, exp);
+    }
+
+    // a = 2 OR 2 < a < 3 -> 2 <= a < 3
+    public void testPropagateEquals_VarEq2OrVarRangeGt2Lt3() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        Range range = new Range(EMPTY, fa, TWO, false, THREE, false);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new Or(EMPTY, eq, range));
+        assertEquals(Range.class, exp.getClass());
+        Range r = (Range) exp;
+        assertEquals(TWO, r.lower());
+        assertTrue(r.includeLower());
+        assertEquals(THREE, r.upper());
+        assertFalse(r.includeUpper());
+    }
+
+    // a = 3 OR 2 < a < 3 -> 2 < a <= 3
+    public void testPropagateEquals_VarEq3OrVarRangeGt2Lt3() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, THREE);
+        Range range = new Range(EMPTY, fa, TWO, false, THREE, false);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new Or(EMPTY, eq, range));
+        assertEquals(Range.class, exp.getClass());
+        Range r = (Range) exp;
+        assertEquals(TWO, r.lower());
+        assertFalse(r.includeLower());
+        assertEquals(THREE, r.upper());
+        assertTrue(r.includeUpper());
+    }
+
+    // a = 2 OR a != 2 -> TRUE
+    public void testPropagateEquals_VarEq2OrVarNeq2() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        NotEquals neq = new NotEquals(EMPTY, fa, TWO);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new Or(EMPTY, eq, neq));
+        assertEquals(TRUE, exp);
+    }
+
+    // a = 2 OR a != 5 -> a != 5
+    public void testPropagateEquals_VarEq2OrVarNeq5() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        NotEquals neq = new NotEquals(EMPTY, fa, FIVE);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(new Or(EMPTY, eq, neq));
+        assertEquals(NotEquals.class, exp.getClass());
+        NotEquals ne = (NotEquals) exp;
+        assertEquals(ne.right(), FIVE);
+    }
+
+    // a = 2 OR 3 < a < 4 OR a > 2 OR a!= 2 -> TRUE
+    public void testPropagateEquals_VarEq2OrVarRangeGt3Lt4OrVarGt2OrVarNe2() {
+        FieldAttribute fa = getFieldAttribute();
+        Equals eq = new Equals(EMPTY, fa, TWO);
+        Range range = new Range(EMPTY, fa, THREE, false, FOUR, false);
+        GreaterThan gt = new GreaterThan(EMPTY, fa, TWO);
+        NotEquals neq = new NotEquals(EMPTY, fa, TWO);
+
+        PropagateEquals rule = new PropagateEquals();
+        Expression exp = rule.rule(Predicates.combineOr(Arrays.asList(eq, range, neq, gt)));
+        assertEquals(TRUE, exp);
+    }
+
     public void testTranslateMinToFirst() {
         Min min1 =  new Min(EMPTY, new FieldAttribute(EMPTY, "str", new EsField("str", DataType.KEYWORD, emptyMap(), true)));
         Min min2 =  new Min(EMPTY, getFieldAttribute());