Просмотр исходного кода

EQL: Propagate key constraints through the query (#62073)

Since join keys are common across all queries in a Join/Sequence, any
constraint applied on one query needs to be obeyed but all the other
queries.
This PR enhances the optimizer to propagate such constraints across
all queries so they get pushed down to the actual generated ES queries.

Fix #58937
Costin Leau 5 лет назад
Родитель
Сommit
4afa5debc1

+ 127 - 4
x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/optimizer/Optimizer.java

@@ -16,10 +16,14 @@ import org.elasticsearch.xpack.eql.util.MathUtils;
 import org.elasticsearch.xpack.eql.util.StringUtils;
 import org.elasticsearch.xpack.eql.util.StringUtils;
 import org.elasticsearch.xpack.ql.expression.Expression;
 import org.elasticsearch.xpack.ql.expression.Expression;
 import org.elasticsearch.xpack.ql.expression.Literal;
 import org.elasticsearch.xpack.ql.expression.Literal;
+import org.elasticsearch.xpack.ql.expression.NamedExpression;
 import org.elasticsearch.xpack.ql.expression.Order;
 import org.elasticsearch.xpack.ql.expression.Order;
 import org.elasticsearch.xpack.ql.expression.Order.NullsPosition;
 import org.elasticsearch.xpack.ql.expression.Order.NullsPosition;
 import org.elasticsearch.xpack.ql.expression.Order.OrderDirection;
 import org.elasticsearch.xpack.ql.expression.Order.OrderDirection;
+import org.elasticsearch.xpack.ql.expression.predicate.Predicates;
+import org.elasticsearch.xpack.ql.expression.predicate.logical.And;
 import org.elasticsearch.xpack.ql.expression.predicate.logical.Not;
 import org.elasticsearch.xpack.ql.expression.predicate.logical.Not;
+import org.elasticsearch.xpack.ql.expression.predicate.logical.Or;
 import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNotNull;
 import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNotNull;
 import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNull;
 import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNull;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison;
@@ -51,6 +55,7 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.List;
 
 
 import static java.util.Collections.singletonList;
 import static java.util.Collections.singletonList;
+import static java.util.stream.Collectors.toList;
 
 
 public class Optimizer extends RuleExecutor<LogicalPlan> {
 public class Optimizer extends RuleExecutor<LogicalPlan> {
 
 
@@ -74,11 +79,15 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
                 new ReplaceNullChecks(),
                 new ReplaceNullChecks(),
                 new PropagateEquals(),
                 new PropagateEquals(),
                 new CombineBinaryComparisons(),
                 new CombineBinaryComparisons(),
+                new PushDownAndCombineFilters(),
                 // prune/elimination
                 // prune/elimination
                 new PruneFilters(),
                 new PruneFilters(),
                 new PruneLiteralsInOrderBy(),
                 new PruneLiteralsInOrderBy(),
                 new CombineLimits());
                 new CombineLimits());
 
 
+        Batch constraints = new Batch("Infer constraints", Limiter.ONCE,
+                new PropagateJoinKeyConstraints());
+
         Batch ordering = new Batch("Implicit Order",
         Batch ordering = new Batch("Implicit Order",
                 new SortByLimit(),
                 new SortByLimit(),
                 new PushDownOrderBy());
                 new PushDownOrderBy());
@@ -91,9 +100,9 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
         Batch label = new Batch("Set as Optimized", Limiter.ONCE,
         Batch label = new Batch("Set as Optimized", Limiter.ONCE,
                 new SetAsOptimized());
                 new SetAsOptimized());
 
 
-        return Arrays.asList(substitutions, operators, ordering, local, label);
+        return Arrays.asList(substitutions, operators, constraints, operators, ordering, local, label);
     }
     }
-    
+
     private static class ReplaceWildcards extends OptimizerRule<Filter> {
     private static class ReplaceWildcards extends OptimizerRule<Filter> {
 
 
         private static boolean isWildcard(Expression expr) {
         private static boolean isWildcard(Expression expr) {
@@ -152,6 +161,25 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
         }
         }
     }
     }
 
 
+    static class PushDownAndCombineFilters extends OptimizerRule<Filter> {
+
+        @Override
+        protected LogicalPlan rule(Filter filter) {
+            LogicalPlan child = filter.child();
+            LogicalPlan plan = filter;
+
+            if (child instanceof Filter) {
+                Filter f = (Filter) child;
+                plan = new Filter(f.source(), f.child(), new And(f.source(), f.condition(), filter.condition()));
+            } else if (child instanceof UnaryPlan) {
+                UnaryPlan up = (UnaryPlan) child;
+                plan = child.replaceChildren(singletonList(new Filter(filter.source(), up.child(), filter.condition())));
+            }
+
+            return plan;
+        }
+    }
+
     static class PruneFilters extends org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PruneFilters {
     static class PruneFilters extends org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PruneFilters {
 
 
         @Override
         @Override
@@ -237,6 +265,101 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
         }
         }
     }
     }
 
 
+
+    /**
+     * Any condition applied on a join/sequence key, gets propagated to all rules.
+     */
+    static class PropagateJoinKeyConstraints extends OptimizerRule<Join> {
+
+        class Constraint {
+            private final Expression condition;
+            private final KeyedFilter keyedFilter;
+            private final int keyPosition;
+
+            Constraint(Expression condition, KeyedFilter filter, int keyPosition) {
+                this.condition = condition;
+                this.keyedFilter = filter;
+                this.keyPosition = keyPosition;
+            }
+
+            Expression constraintFor(KeyedFilter keyed) {
+                if (keyed == keyedFilter) {
+                    return null;
+                }
+
+                Expression localKey = keyed.keys().get(keyPosition);
+                Expression key = keyedFilter.keys().get(keyPosition);
+
+                Expression newCond = condition.transformDown(e -> key.semanticEquals(e) ? localKey : e);
+                return newCond;
+            }
+
+            @Override
+            public String toString() {
+                return condition.toString();
+            }
+        }
+
+        @Override
+        protected LogicalPlan rule(Join join) {
+            List<Constraint> constraints = new ArrayList<>();
+
+            // collect constraints for each filter
+            join.queries().forEach(k ->
+                k.forEachDown(f -> constraints.addAll(detectKeyConstraints(f.condition(), k))
+                                  , Filter.class));
+
+            if (constraints.isEmpty() == false) {
+                List<KeyedFilter> queries = join.queries().stream()
+                        .map(k -> addConstraint(k, constraints))
+                        .collect(toList());
+
+                join = join.with(queries, join.until(), join.direction());
+            }
+
+            return join;
+        }
+
+        private List<Constraint> detectKeyConstraints(Expression condition, KeyedFilter filter) {
+            List<Constraint> constraints = new ArrayList<>();
+            List<? extends NamedExpression> keys = filter.keys();
+
+            List<Expression> and = Predicates.splitAnd(condition);
+            for (Expression exp : and) {
+                // if there are no conjunction and at least one key matches, save the expression along with the key
+                // and its ordinal so it can be replaced
+                if (exp.anyMatch(Or.class::isInstance) == false) {
+                    // comparisons against variables are not done
+                    // hence why on the first key match, the expression is picked up
+                    exp.anyMatch(e -> {
+                        for (int i = 0; i < keys.size(); i++) {
+                            Expression key = keys.get(i);
+                            if (e.semanticEquals(key)) {
+                                constraints.add(new Constraint(exp, filter, i));
+                                return true;
+                            }
+                        }
+                        return false;
+                    });
+                }
+            }
+            return constraints;
+        }
+
+        // adapt constraint to the given filter by replacing the keys accordingly in the expressions
+        private KeyedFilter addConstraint(KeyedFilter k, List<Constraint> constraints) {
+            Expression constraint = Predicates.combineAnd(constraints.stream()
+                .map(c -> c.constraintFor(k))
+                .filter(c -> c != null)
+                .collect(toList()));
+
+            return constraint != null
+                    ? new KeyedFilter(k.source(), new Filter(k.source(), k.child(), constraint), k.keys(), k.timestamp(), k.tiebreaker())
+                    : k;
+        }
+    }
+
+
     /**
     /**
      * Align the implicit order with the limit (head means ASC or tail means DESC).
      * Align the implicit order with the limit (head means ASC or tail means DESC).
      */
      */
@@ -256,7 +379,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
                     }
                     }
                 }
                 }
             }
             }
-            
+
             return limit;
             return limit;
         }
         }
     }
     }
@@ -341,4 +464,4 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
             return plan;
             return plan;
         }
         }
     }
     }
-}
+}

+ 425 - 27
x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/optimizer/OptimizerTests.java

@@ -22,15 +22,21 @@ import org.elasticsearch.xpack.eql.plan.physical.LocalRelation;
 import org.elasticsearch.xpack.eql.stats.Metrics;
 import org.elasticsearch.xpack.eql.stats.Metrics;
 import org.elasticsearch.xpack.ql.expression.Attribute;
 import org.elasticsearch.xpack.ql.expression.Attribute;
 import org.elasticsearch.xpack.ql.expression.EmptyAttribute;
 import org.elasticsearch.xpack.ql.expression.EmptyAttribute;
+import org.elasticsearch.xpack.ql.expression.Expression;
 import org.elasticsearch.xpack.ql.expression.FieldAttribute;
 import org.elasticsearch.xpack.ql.expression.FieldAttribute;
 import org.elasticsearch.xpack.ql.expression.Literal;
 import org.elasticsearch.xpack.ql.expression.Literal;
+import org.elasticsearch.xpack.ql.expression.NamedExpression;
 import org.elasticsearch.xpack.ql.expression.Order;
 import org.elasticsearch.xpack.ql.expression.Order;
 import org.elasticsearch.xpack.ql.expression.Order.NullsPosition;
 import org.elasticsearch.xpack.ql.expression.Order.NullsPosition;
 import org.elasticsearch.xpack.ql.expression.Order.OrderDirection;
 import org.elasticsearch.xpack.ql.expression.Order.OrderDirection;
 import org.elasticsearch.xpack.ql.expression.predicate.logical.And;
 import org.elasticsearch.xpack.ql.expression.predicate.logical.And;
 import org.elasticsearch.xpack.ql.expression.predicate.logical.Not;
 import org.elasticsearch.xpack.ql.expression.predicate.logical.Not;
+import org.elasticsearch.xpack.ql.expression.predicate.logical.Or;
 import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNotNull;
 import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNotNull;
 import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNull;
 import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNull;
+import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals;
+import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThan;
+import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.LessThan;
 import org.elasticsearch.xpack.ql.expression.predicate.regex.Like;
 import org.elasticsearch.xpack.ql.expression.predicate.regex.Like;
 import org.elasticsearch.xpack.ql.index.EsIndex;
 import org.elasticsearch.xpack.ql.index.EsIndex;
 import org.elasticsearch.xpack.ql.index.IndexResolution;
 import org.elasticsearch.xpack.ql.index.IndexResolution;
@@ -41,20 +47,25 @@ import org.elasticsearch.xpack.ql.plan.logical.OrderBy;
 import org.elasticsearch.xpack.ql.plan.logical.Project;
 import org.elasticsearch.xpack.ql.plan.logical.Project;
 import org.elasticsearch.xpack.ql.plan.logical.UnaryPlan;
 import org.elasticsearch.xpack.ql.plan.logical.UnaryPlan;
 import org.elasticsearch.xpack.ql.plan.logical.UnresolvedRelation;
 import org.elasticsearch.xpack.ql.plan.logical.UnresolvedRelation;
-import org.elasticsearch.xpack.ql.type.DataTypes;
 import org.elasticsearch.xpack.ql.type.EsField;
 import org.elasticsearch.xpack.ql.type.EsField;
 import org.elasticsearch.xpack.ql.type.TypesTests;
 import org.elasticsearch.xpack.ql.type.TypesTests;
 
 
+import java.time.ZoneId;
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
+import java.util.stream.Stream;
 
 
 import static java.util.Arrays.asList;
 import static java.util.Arrays.asList;
 import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.singletonList;
 import static java.util.Collections.singletonList;
+import static java.util.stream.Collectors.toList;
 import static org.elasticsearch.xpack.eql.EqlTestUtils.TEST_CFG_CASE_INSENSITIVE;
 import static org.elasticsearch.xpack.eql.EqlTestUtils.TEST_CFG_CASE_INSENSITIVE;
+import static org.elasticsearch.xpack.ql.TestUtils.UTC;
+import static org.elasticsearch.xpack.ql.expression.Literal.TRUE;
 import static org.elasticsearch.xpack.ql.tree.Source.EMPTY;
 import static org.elasticsearch.xpack.ql.tree.Source.EMPTY;
+import static org.elasticsearch.xpack.ql.type.DataTypes.INTEGER;
 
 
 public class OptimizerTests extends ESTestCase {
 public class OptimizerTests extends ESTestCase {
 
 
@@ -77,7 +88,7 @@ public class OptimizerTests extends ESTestCase {
         PostAnalyzer postAnalyzer = new PostAnalyzer();
         PostAnalyzer postAnalyzer = new PostAnalyzer();
         Analyzer analyzer = new Analyzer(TEST_CFG_CASE_INSENSITIVE, new EqlFunctionRegistry(), new Verifier(new Metrics()));
         Analyzer analyzer = new Analyzer(TEST_CFG_CASE_INSENSITIVE, new EqlFunctionRegistry(), new Verifier(new Metrics()));
         return optimizer.optimize(postAnalyzer.postAnalyze(analyzer.analyze(preAnalyzer.preAnalyze(parser.createStatement(eql),
         return optimizer.optimize(postAnalyzer.postAnalyze(analyzer.analyze(preAnalyzer.preAnalyze(parser.createStatement(eql),
-                resolution)), TEST_CFG_CASE_INSENSITIVE));
+            resolution)), TEST_CFG_CASE_INSENSITIVE));
     }
     }
 
 
     private LogicalPlan accept(String eql) {
     private LogicalPlan accept(String eql) {
@@ -102,6 +113,7 @@ public class OptimizerTests extends ESTestCase {
             assertEquals(((FieldAttribute) check.field()).name(), "command_line");
             assertEquals(((FieldAttribute) check.field()).name(), "command_line");
         }
         }
     }
     }
+
     public void testIsNotNull() {
     public void testIsNotNull() {
         List<String> tests = Arrays.asList(
         List<String> tests = Arrays.asList(
             "foo where command_line != null",
             "foo where command_line != null",
@@ -231,26 +243,10 @@ public class OptimizerTests extends ESTestCase {
         assertEquals("Incorrect limit", limit, lo.limit().fold());
         assertEquals("Incorrect limit", limit, lo.limit().fold());
     }
     }
 
 
-    private static Attribute timestamp() {
-        return new FieldAttribute(EMPTY, "test", new EsField("field", DataTypes.INTEGER, emptyMap(), true));
-    }
-
-    private static Attribute tiebreaker() {
-        return new EmptyAttribute(EMPTY);
-    }
-
-    private static LogicalPlan rel() {
-        return new UnresolvedRelation(EMPTY, new TableIdentifier(EMPTY, "catalog", "index"), "", false);
-    }
-
-    private static KeyedFilter keyedFilter(LogicalPlan child) {
-        return new KeyedFilter(EMPTY, child, emptyList(), timestamp(), tiebreaker());
-    }
-
     public void testSkipQueryOnLimitZero() {
     public void testSkipQueryOnLimitZero() {
         KeyedFilter rule1 = keyedFilter(new LocalRelation(EMPTY, emptyList()));
         KeyedFilter rule1 = keyedFilter(new LocalRelation(EMPTY, emptyList()));
-        KeyedFilter rule2 = keyedFilter(new Filter(EMPTY, rel(), new IsNull(EMPTY, Literal.TRUE)));
-        KeyedFilter until = keyedFilter(new Filter(EMPTY, rel(), Literal.FALSE));
+        KeyedFilter rule2 = keyedFilter(basicFilter(new IsNull(EMPTY, TRUE)));
+        KeyedFilter until = keyedFilter(basicFilter(Literal.FALSE));
         Sequence s = new Sequence(EMPTY, asList(rule1, rule2), until, TimeValue.MINUS_ONE, timestamp(), tiebreaker(), OrderDirection.ASC);
         Sequence s = new Sequence(EMPTY, asList(rule1, rule2), until, TimeValue.MINUS_ONE, timestamp(), tiebreaker(), OrderDirection.ASC);
 
 
         LogicalPlan optimized = optimizer.optimize(s);
         LogicalPlan optimized = optimizer.optimize(s);
@@ -260,7 +256,7 @@ public class OptimizerTests extends ESTestCase {
     public void testSortByLimit() {
     public void testSortByLimit() {
         Project p = new Project(EMPTY, rel(), emptyList());
         Project p = new Project(EMPTY, rel(), emptyList());
         OrderBy o = new OrderBy(EMPTY, p, singletonList(new Order(EMPTY, tiebreaker(), OrderDirection.ASC, NullsPosition.FIRST)));
         OrderBy o = new OrderBy(EMPTY, p, singletonList(new Order(EMPTY, tiebreaker(), OrderDirection.ASC, NullsPosition.FIRST)));
-        Tail t = new Tail(EMPTY, new Literal(EMPTY, 1, DataTypes.INTEGER), o);
+        Tail t = new Tail(EMPTY, new Literal(EMPTY, 1, INTEGER), o);
 
 
         LogicalPlan optimized = new Optimizer.SortByLimit().rule(t);
         LogicalPlan optimized = new Optimizer.SortByLimit().rule(t);
         assertEquals(LimitWithOffset.class, optimized.getClass());
         assertEquals(LimitWithOffset.class, optimized.getClass());
@@ -269,11 +265,10 @@ public class OptimizerTests extends ESTestCase {
     }
     }
 
 
     public void testPushdownOrderBy() {
     public void testPushdownOrderBy() {
-        Filter filter = new Filter(EMPTY, rel(), new IsNull(EMPTY, Literal.TRUE));
+        Filter filter = basicFilter(new IsNull(EMPTY, TRUE));
         KeyedFilter rule1 = keyedFilter(filter);
         KeyedFilter rule1 = keyedFilter(filter);
         KeyedFilter rule2 = keyedFilter(filter);
         KeyedFilter rule2 = keyedFilter(filter);
-        KeyedFilter until = keyedFilter(filter);
-        Sequence s = new Sequence(EMPTY, asList(rule1, rule2), until, TimeValue.MINUS_ONE, timestamp(), tiebreaker(), OrderDirection.ASC);
+        Sequence s = sequence(rule1, rule2);
         OrderBy o = new OrderBy(EMPTY, s, singletonList(new Order(EMPTY, tiebreaker(), OrderDirection.DESC, NullsPosition.FIRST)));
         OrderBy o = new OrderBy(EMPTY, s, singletonList(new Order(EMPTY, tiebreaker(), OrderDirection.DESC, NullsPosition.FIRST)));
 
 
         LogicalPlan optimized = new Optimizer.PushDownOrderBy().rule(o);
         LogicalPlan optimized = new Optimizer.PushDownOrderBy().rule(o);
@@ -285,14 +280,391 @@ public class OptimizerTests extends ESTestCase {
         assertOrder(seq.queries().get(1), OrderDirection.ASC);
         assertOrder(seq.queries().get(1), OrderDirection.ASC);
     }
     }
 
 
-    private void assertOrder(UnaryPlan plan, OrderDirection direction) {
+    /**
+     * Filter X
+     * Filter Y
+     * ==
+     * Filter X and Y
+     */
+    public void testCombineFilters() {
+        Expression left = new IsNull(EMPTY, TRUE);
+        Expression right = equalsExpression();
+
+        Filter filterChild = basicFilter(left);
+        Filter filterParent = new Filter(EMPTY, filterChild, right);
+
+        LogicalPlan result = new Optimizer.PushDownAndCombineFilters().apply(filterParent);
+
+        assertEquals(Filter.class, result.getClass());
+        Expression condition = ((Filter) result).condition();
+        assertEquals(And.class, condition.getClass());
+        And and = (And) condition;
+        assertEquals(left, and.left());
+        assertEquals(right, and.right());
+    }
+
+    /**
+     * Filter X
+     * UnaryNode
+     * LeafNode
+     * ==
+     * UnaryNode
+     * Filter X
+     * LeafNode
+     */
+    public void testPushDownFilterUnary() {
+        Expression left = new IsNull(EMPTY, TRUE);
+
+        OrderBy order = new OrderBy(EMPTY, rel(), emptyList());
+        Filter filter = new Filter(EMPTY, order, left);
+
+        LogicalPlan result = new Optimizer.PushDownAndCombineFilters().apply(filter);
+
+        assertEquals(OrderBy.class, result.getClass());
+        OrderBy o = (OrderBy) result;
+        assertEquals(Filter.class, o.child().getClass());
+        Filter f = (Filter) o.child();
+
+        assertEquals(rel(), f.child());
+        assertEquals(filter.condition(), f.condition());
+    }
+
+    /**
+     * Filter
+     * LeafNode
+     * ==
+     * Filter
+     * LeafNode
+     */
+    public void testPushDownFilterDoesNotApplyOnNonUnary() {
+        Expression left = new IsNull(EMPTY, TRUE);
+
+        KeyedFilter rule1 = keyedFilter(new LocalRelation(EMPTY, emptyList()));
+        KeyedFilter rule2 = keyedFilter(basicFilter(new IsNull(EMPTY, TRUE)));
+
+        Sequence s = sequence(rule1, rule2);
+        Filter filter = new Filter(EMPTY, s, left);
+
+        LogicalPlan result = new Optimizer.PushDownAndCombineFilters().apply(filter);
+
+        assertEquals(Filter.class, result.getClass());
+        Filter f = (Filter) result;
+        assertEquals(s, f.child());
+    }
+
+    /**
+     * sequence
+     * 1. filter a gt 1 by a
+     * 2. filter X by a
+     * ==
+     * sequence
+     * 1. filter a gt 1 by a
+     * 2. filter a gt 1 by a
+     * \filter X
+     */
+    public void testKeySameConstraints() {
+        ZoneId zd = randomZone();
+        Attribute a = key("a");
+
+        Expression keyCondition = gtExpression(a);
+        Expression filter = equalsExpression();
+
+        KeyedFilter rule1 = keyedFilter(basicFilter(keyCondition), a);
+        KeyedFilter rule2 = keyedFilter(basicFilter(filter), a);
+
+        Sequence s = sequence(rule1, rule2);
+
+        LogicalPlan result = new Optimizer.PropagateJoinKeyConstraints().apply(s);
+
+        assertEquals(Sequence.class, result.getClass());
+        Sequence seq = (Sequence) result;
+
+        List<KeyedFilter> queries = seq.queries();
+        assertEquals(rule1, queries.get(0));
+        KeyedFilter query2 = queries.get(1);
+        assertEquals(keyCondition, filterCondition(query2.child()));
+        assertEquals(filter, filterCondition(query2.child().children().get(0)));
+    }
+
+    /**
+     * sequence
+     * 1. filter a gt 1 by a
+     * 2. filter b == true by b
+     * ==
+     * sequence
+     * 1. filter a == true by a
+     * \filter a gt 1
+     * 2. filter b gt 1 by b
+     * \filter b == true
+     */
+    public void testSameTwoKeysConstraints() {
+        Attribute a = key("a");
+        Attribute b = key("b");
+
+        Expression keyACondition = gtExpression(a);
+        Expression keyBCondition = new Equals(EMPTY, b, TRUE);
+
+        KeyedFilter rule1 = keyedFilter(basicFilter(keyACondition), a, b);
+        KeyedFilter rule2 = keyedFilter(basicFilter(keyBCondition), a, b);
+
+        Sequence s = sequence(rule1, rule2);
+
+        LogicalPlan result = new Optimizer.PropagateJoinKeyConstraints().apply(s);
+
+        assertEquals(Sequence.class, result.getClass());
+        Sequence seq = (Sequence) result;
+
+        List<KeyedFilter> queries = seq.queries();
+        KeyedFilter query1 = queries.get(0);
+        assertEquals(keyBCondition, filterCondition(query1.child()));
+        assertEquals(keyACondition, filterCondition(query1.child().children().get(0)));
+
+        KeyedFilter query2 = queries.get(1);
+        assertEquals(keyACondition, filterCondition(query2.child()));
+        assertEquals(keyBCondition, filterCondition(query2.child().children().get(0)));
+    }
+
+    /**
+     * sequence
+     * 1. filter a gt 1 by a
+     * 2. filter b == 1 by b
+     * ==
+     * sequence
+     * 1. filter a == 1 by a
+     * \filter a gt 1
+     * 2. filter b gt 1 by b
+     * \filter b == 1
+     */
+    public void testDifferentOneKeyConstraints() {
+        ZoneId zd = randomZone();
+        Attribute a = key("a");
+        Attribute b = key("b");
+
+        Expression keyARuleACondition = gtExpression(a);
+        Expression keyBRuleACondition = gtExpression(b);
+
+        Expression keyARuleBCondition = new Equals(EMPTY, a, TRUE);
+        Expression keyBRuleBCondition = new Equals(EMPTY, b, TRUE);
+
+        KeyedFilter rule1 = keyedFilter(basicFilter(keyARuleACondition), a);
+        KeyedFilter rule2 = keyedFilter(basicFilter(keyBRuleBCondition), b);
+
+        Sequence s = sequence(rule1, rule2);
+
+        LogicalPlan result = new Optimizer.PropagateJoinKeyConstraints().apply(s);
+
+        assertEquals(Sequence.class, result.getClass());
+        Sequence seq = (Sequence) result;
+
+        List<KeyedFilter> queries = seq.queries();
+        KeyedFilter query1 = queries.get(0);
+
+        assertEquals(keyARuleBCondition, filterCondition(query1.child()));
+        assertEquals(keyARuleACondition, filterCondition(query1.child().children().get(0)));
+
+        KeyedFilter query2 = queries.get(1);
+        assertEquals(keyBRuleACondition, filterCondition(query2.child()));
+        assertEquals(keyBRuleBCondition, filterCondition(query2.child().children().get(0)));
+    }
+
+    /**
+     * sequence
+     * 1. filter a1 gt 1 and a2 lt 1 by a1, a2
+     * 2. filter someKey == true by b1, b2
+     * ==
+     * sequence
+     * 1. filter a1 gt 1 and a2 gt 1 by a1, a2
+     * 2. filter b1 gt 1 and b2 gt 1 by b1, b2
+     * \filter someKey == true
+     */
+    public void testQueryLevelTwoKeyConstraints() {
+        ZoneId zd = randomZone();
+        Attribute a1 = key("a1");
+        Attribute a2 = key("a2");
+
+        Attribute b1 = key("b1");
+        Attribute b2 = key("b2");
+
+        Expression keyA1RuleACondition = gtExpression(a1);
+        Expression keyA2RuleACondition = new LessThan(EMPTY, a2, new Literal(EMPTY, 1, INTEGER), zd);
+        Expression ruleACondition = new And(EMPTY, keyA1RuleACondition, keyA2RuleACondition);
+
+        Expression ruleBCondition = new Equals(EMPTY, key("someKey"), TRUE);
+
+        KeyedFilter rule1 = keyedFilter(basicFilter(ruleACondition), a1, a2);
+        KeyedFilter rule2 = keyedFilter(basicFilter(ruleBCondition), b1, b2);
+
+        Sequence s = sequence(rule1, rule2);
+
+        LogicalPlan result = new Optimizer.PropagateJoinKeyConstraints().apply(s);
+
+        assertEquals(Sequence.class, result.getClass());
+        Sequence seq = (Sequence) result;
+
+        List<KeyedFilter> queries = seq.queries();
+        KeyedFilter query1 = queries.get(0);
+
+        assertEquals(rule1, query1);
+
+        KeyedFilter query2 = queries.get(1);
+        // rewrite constraints for key B
+        Expression keyB1RuleACondition = gtExpression(b1);
+        Expression keyB2RuleACondition = new LessThan(EMPTY, b2, new Literal(EMPTY, 1, INTEGER), zd);
+
+        assertEquals(new And(EMPTY, keyB1RuleACondition, keyB2RuleACondition), filterCondition(query2.child()));
+        assertEquals(ruleBCondition, filterCondition(query2.child().children().get(0)));
+    }
+
+    /**
+     * Key conditions inside a disjunction (OR) are ignored
+     * <p>
+     * sequence
+     * 1. filter a gt 1 OR x == 1 by a
+     * 2. filter x == 1 by b
+     * ==
+     * same
+     */
+    public void testSkipKeySameWithDisjunctionConstraints() {
+        ZoneId zd = randomZone();
+        Attribute a = key("a");
+
+        Expression keyCondition = gtExpression(a);
+        Expression filter = equalsExpression();
+        Expression cond = new Or(EMPTY, filter, keyCondition);
+
+        KeyedFilter rule1 = keyedFilter(basicFilter(cond), a);
+        KeyedFilter rule2 = keyedFilter(basicFilter(filter), a);
+
+        Sequence s = sequence(rule1, rule2);
+
+        LogicalPlan result = new Optimizer.PropagateJoinKeyConstraints().apply(s);
+
+        assertEquals(Sequence.class, result.getClass());
+        Sequence seq = (Sequence) result;
+
+        List<KeyedFilter> queries = seq.queries();
+        assertEquals(rule1, queries.get(0));
+        assertEquals(rule2, queries.get(1));
+    }
+
+    /**
+     * Key conditions inside a conjunction (AND) are picked up
+     * <p>
+     * sequence
+     * 1. filter a gt 1 and x == 1 by a
+     * 2. filter x == 1 by b
+     * ==
+     * sequence
+     * 1. filter a gt 1 and x == 1 by a
+     * 2. filter b gt 1 by b
+     * \filter x == 1
+     */
+    public void testExtractKeySameFromDisjunction() {
+        ZoneId zd = randomZone();
+        Attribute a = key("a");
+
+        Expression keyCondition = gtExpression(a);
+        Expression filter = equalsExpression();
+
+        Expression cond = new And(EMPTY, filter, keyCondition);
+
+        KeyedFilter rule1 = keyedFilter(basicFilter(cond), a);
+        KeyedFilter rule2 = keyedFilter(basicFilter(filter), a);
+
+        Sequence s = sequence(rule1, rule2);
+
+        LogicalPlan result = new Optimizer.PropagateJoinKeyConstraints().apply(s);
+
+        assertEquals(Sequence.class, result.getClass());
+        Sequence seq = (Sequence) result;
+
+        List<KeyedFilter> queries = seq.queries();
+        assertEquals(rule1, queries.get(0));
+
+        KeyedFilter query2 = queries.get(1);
+        LogicalPlan child2 = query2.child();
+
+        Expression keyRuleBCondition = gtExpression(a);
+
+        assertEquals(keyRuleBCondition, filterCondition(child2));
+        assertEquals(filter, filterCondition(child2.children().get(0)));
+    }
+
+    /**
+     * Multiple key conditions inside a conjunction (AND) are picked up
+     * <p>
+     * sequence
+     * 1. filter a gt 1 and x by a
+     * 2. filter x by b
+     * =
+     * sequence
+     * 1. filter a gt 1 and x by a
+     * 2. filter b gt 1 by b
+     * \filter x
+     */
+    public void testDifferentKeyFromDisjunction() {
+        Attribute a = key("a");
+        Attribute b = key("b");
+
+        Expression keyARuleACondition = gtExpression(a);
+        Expression filter = equalsExpression();
+
+        Expression cond = new And(EMPTY, filter, new And(EMPTY, keyARuleACondition, filter));
+
+        KeyedFilter rule1 = keyedFilter(basicFilter(cond), a);
+        KeyedFilter rule2 = keyedFilter(basicFilter(filter), b);
+
+        Sequence s = sequence(rule1, rule2);
+
+        LogicalPlan result = new Optimizer.PropagateJoinKeyConstraints().apply(s);
+
+        assertEquals(Sequence.class, result.getClass());
+        Sequence seq = (Sequence) result;
+
+        List<KeyedFilter> queries = seq.queries();
+        assertEquals(rule1, queries.get(0));
+
+        KeyedFilter query2 = queries.get(1);
+        LogicalPlan child2 = query2.child();
+
+        Expression keyRuleBCondition = gtExpression(b);
+
+        assertEquals(keyRuleBCondition, filterCondition(child2));
+        assertEquals(filter, filterCondition(child2.children().get(0)));
+    }
+
+    private static Attribute timestamp() {
+        return new FieldAttribute(EMPTY, "test", new EsField("field", INTEGER, emptyMap(), true));
+    }
+
+    private static Attribute tiebreaker() {
+        return new EmptyAttribute(EMPTY);
+    }
+
+    private static LogicalPlan rel() {
+        return new UnresolvedRelation(EMPTY, new TableIdentifier(EMPTY, "catalog", "index"), "", false);
+    }
+
+    private static KeyedFilter keyedFilter(LogicalPlan child) {
+        return new KeyedFilter(EMPTY, child, emptyList(), timestamp(), tiebreaker());
+    }
+
+    private static KeyedFilter keyedFilter(LogicalPlan child, NamedExpression... keys) {
+        return new KeyedFilter(EMPTY, child, asList(keys), timestamp(), tiebreaker());
+    }
+
+    private static Attribute key(String name) {
+        return new FieldAttribute(EMPTY, name, new EsField(name, INTEGER, emptyMap(), true));
+    }
+
+    private static void assertOrder(UnaryPlan plan, OrderDirection direction) {
         assertEquals(OrderBy.class, plan.child().getClass());
         assertEquals(OrderBy.class, plan.child().getClass());
         OrderBy orderBy = (OrderBy) plan.child();
         OrderBy orderBy = (OrderBy) plan.child();
         Order order = orderBy.order().get(0);
         Order order = orderBy.order().get(0);
         assertEquals(direction, order.direction());
         assertEquals(direction, order.direction());
     }
     }
 
 
-    private LogicalPlan defaultPipes(LogicalPlan plan) {
+    private static LogicalPlan defaultPipes(LogicalPlan plan) {
         assertTrue(plan instanceof Project);
         assertTrue(plan instanceof Project);
         plan = ((Project) plan).child();
         plan = ((Project) plan).child();
         assertTrue(plan instanceof LimitWithOffset);
         assertTrue(plan instanceof LimitWithOffset);
@@ -300,4 +672,30 @@ public class OptimizerTests extends ESTestCase {
         assertTrue(plan instanceof OrderBy);
         assertTrue(plan instanceof OrderBy);
         return ((OrderBy) plan).child();
         return ((OrderBy) plan).child();
     }
     }
-}
+
+    private static Sequence sequence(LogicalPlan... rules) {
+        List<KeyedFilter> collect = Stream.of(rules)
+            .map(r -> r instanceof KeyedFilter ? (KeyedFilter) r : keyedFilter(r))
+            .collect(toList());
+
+        return new Sequence(EMPTY, collect, keyedFilter(rel()), TimeValue.MINUS_ONE, timestamp(), tiebreaker(), OrderDirection.ASC);
+    }
+
+    private static Expression filterCondition(LogicalPlan plan) {
+        assertEquals(Filter.class, plan.getClass());
+        Filter f = (Filter) plan;
+        return f.condition();
+    }
+
+    private static Filter basicFilter(Expression filter) {
+        return new Filter(EMPTY, rel(), filter);
+    }
+
+    private static Equals equalsExpression() {
+        return new Equals(EMPTY, timestamp(), TRUE);
+    }
+
+    private static GreaterThan gtExpression(Attribute b) {
+        return new GreaterThan(EMPTY, b, new Literal(EMPTY, 1, INTEGER), UTC);
+    }
+}