Browse Source

QL: Combine multiple equal queries into In (#65353)

Some queries can have an excessive number of OR clauses on the same
field (typically by being generated). This can lead to large query
requests that can cause SO errors.
This PR tries to address this pattern by combining the equals into an
In expression which in turn gets compressed into a terms query vs one
term query per entry.

Fix #62804
Fix #46477
Costin Leau 4 years ago
parent
commit
7505d7d61c

+ 2 - 0
x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/optimizer/Optimizer.java

@@ -36,6 +36,7 @@ import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BooleanFunctionEquals
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BooleanLiteralsOnTheRight;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BooleanSimplification;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.CombineBinaryComparisons;
+import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.CombineDisjunctionsToIn;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.ConstantFolding;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.OptimizerRule;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PropagateEquals;
@@ -83,6 +84,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
                 new ReplaceNullChecks(),
                 new PropagateEquals(),
                 new CombineBinaryComparisons(),
+                new CombineDisjunctionsToIn(),
                 new PushDownAndCombineFilters(),
                 // prune/elimination
                 new PruneFilters(),

+ 84 - 0
x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java

@@ -24,6 +24,7 @@ import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Binar
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThan;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThanOrEqual;
+import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.In;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.LessThan;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.LessThanOrEqual;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.NotEquals;
@@ -38,9 +39,15 @@ import org.elasticsearch.xpack.ql.rule.Rule;
 import org.elasticsearch.xpack.ql.type.DataTypes;
 import org.elasticsearch.xpack.ql.util.CollectionUtils;
 
+import java.time.ZoneId;
 import java.util.ArrayList;
 import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
 
 import static org.elasticsearch.xpack.ql.expression.Literal.FALSE;
 import static org.elasticsearch.xpack.ql.expression.Literal.TRUE;
@@ -1036,6 +1043,83 @@ public final class OptimizerRules {
 
     }
 
+    /**
+     * Combine disjunctions on the same field into an In expression.
+     * This rule looks for both simple equalities:
+     * 1. a == 1 OR a == 2 becomes a IN (1, 2)
+     * and combinations of In
+     * 2. a == 1 OR a IN (2) becomes a IN (1, 2)
+     * 3. a IN (1) OR a IN (2) becomes a IN (1, 2)
+     *
+     * This rule does NOT check for type compatibility as that phase has been
+     * already be verified in the analyzer.
+     */
+    public static class CombineDisjunctionsToIn extends OptimizerExpressionRule {
+        public CombineDisjunctionsToIn() {
+            super(TransformDirection.UP);
+        }
+
+        @Override
+        protected Expression rule(Expression e) {
+            if (e instanceof Or) {
+                // look only at equals and In
+                List<Expression> exps = splitOr(e);
+
+                Map<Expression, Set<Expression>> found = new LinkedHashMap<>();
+                ZoneId zoneId = null;
+                List<Expression> ors = new LinkedList<>();
+
+                for (Expression exp : exps) {
+                    if (exp instanceof Equals) {
+                        Equals eq = (Equals) exp;
+                        // consider only equals against foldables
+                        if (eq.right().foldable()) {
+                            found.computeIfAbsent(eq.left(), k -> new LinkedHashSet<>()).add(eq.right());
+                        } else {
+                            ors.add(exp);
+                        }
+                        if (zoneId == null) {
+                            zoneId = eq.zoneId();
+                        }
+                    }
+                    else if (exp instanceof In) {
+                        In in = (In) exp;
+                        found.computeIfAbsent(in.value(), k -> new LinkedHashSet<>()).addAll(in.list());
+                        if (zoneId == null) {
+                            zoneId = in.zoneId();
+                        }
+                    } else {
+                        ors.add(exp);
+                    }
+                }
+
+                if (found.isEmpty() == false) {
+                    // combine equals alongside the existing ors
+                    final ZoneId finalZoneId = zoneId;
+                    found.forEach((k, v) -> {
+                        ors.add(v.size() == 1
+                            ? new Equals(k.source(), k, v.iterator().next(), finalZoneId)
+                            : createIn(k, new ArrayList<>(v), finalZoneId));
+                    });
+
+                    Expression combineOr = combineOr(ors);
+                    // check the result semantically since the result might different in order
+                    // but be actually the same which can trigger a loop
+                    // e.g. a == 1 OR a == 2 OR null --> null OR a in (1,2) --> literalsOnTheRight --> cycle
+                    if (e.semanticEquals(combineOr) == false) {
+                        e = combineOr;
+                    }
+                }
+            }
+
+            return e;
+        }
+
+        protected In createIn(Expression key, List<Expression> values, ZoneId zoneId) {
+            return new In(key.source(), key, values, zoneId);
+        }
+    }
+
     public static class ReplaceSurrogateFunction extends OptimizerExpressionRule {
 
         public ReplaceSurrogateFunction() {

+ 136 - 13
x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRulesTests.java

@@ -28,6 +28,7 @@ import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Binar
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThan;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThanOrEqual;
+import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.In;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.LessThan;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.LessThanOrEqual;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.NotEquals;
@@ -49,11 +50,12 @@ import org.elasticsearch.xpack.ql.type.EsField;
 import org.elasticsearch.xpack.ql.util.StringUtils;
 
 import java.time.ZoneId;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
+import static java.util.Arrays.asList;
 import static java.util.Collections.emptyMap;
+import static java.util.Collections.singletonList;
 import static org.elasticsearch.xpack.ql.TestUtils.equalsOf;
 import static org.elasticsearch.xpack.ql.TestUtils.greaterThanOf;
 import static org.elasticsearch.xpack.ql.TestUtils.greaterThanOrEqualOf;
@@ -66,9 +68,12 @@ import static org.elasticsearch.xpack.ql.TestUtils.rangeOf;
 import static org.elasticsearch.xpack.ql.expression.Literal.FALSE;
 import static org.elasticsearch.xpack.ql.expression.Literal.NULL;
 import static org.elasticsearch.xpack.ql.expression.Literal.TRUE;
+import static org.elasticsearch.xpack.ql.optimizer.OptimizerRules.CombineDisjunctionsToIn;
+import static org.elasticsearch.xpack.ql.optimizer.OptimizerRules.ReplaceRegexMatch;
 import static org.elasticsearch.xpack.ql.tree.Source.EMPTY;
 import static org.elasticsearch.xpack.ql.type.DataTypes.BOOLEAN;
 import static org.elasticsearch.xpack.ql.type.DataTypes.INTEGER;
+import static org.hamcrest.Matchers.contains;
 
 public class OptimizerRulesTests extends ESTestCase {
 
@@ -289,7 +294,7 @@ public class OptimizerRulesTests extends ESTestCase {
 
         FieldAttribute field = getFieldAttribute();
 
-        List<? extends BinaryComparison> comparisons = Arrays.asList(
+        List<? extends BinaryComparison> comparisons = asList(
             new Equals(EMPTY, field, TRUE),
             new Equals(EMPTY, field, FALSE),
             notEqualsOf(field, TRUE),
@@ -504,7 +509,7 @@ public class OptimizerRulesTests extends ESTestCase {
         LessThan blt4 = new LessThan(EMPTY, fb, FOUR, zoneId);
         LessThan clt4 = new LessThan(EMPTY, fc, FOUR, zoneId);
 
-        Expression inputAnd = Predicates.combineAnd(Arrays.asList(agt1, alt3, bgt2, blt4, clt4));
+        Expression inputAnd = Predicates.combineAnd(asList(agt1, alt3, bgt2, blt4, clt4));
 
         CombineBinaryComparisons rule = new CombineBinaryComparisons();
         Expression outputAnd = rule.rule(inputAnd);
@@ -513,7 +518,7 @@ public class OptimizerRulesTests extends ESTestCase {
         Range bgt2lt4 = new Range(EMPTY, fb, TWO, false, FOUR, false, zoneId);
 
         // The actual outcome is (c < 4) AND (1 < a < 3) AND (2 < b < 4), due to the way the Expression types are combined in the Optimizer
-        Expression expectedAnd = Predicates.combineAnd(Arrays.asList(clt4, agt1lt3, bgt2lt4));
+        Expression expectedAnd = Predicates.combineAnd(asList(clt4, agt1lt3, bgt2lt4));
 
         assertTrue(outputAnd.semanticEquals(expectedAnd));
     }
@@ -1188,7 +1193,7 @@ public class OptimizerRulesTests extends ESTestCase {
         NotEquals neq = notEqualsOf(fa, FOUR);
 
         PropagateEquals rule = new PropagateEquals();
-        Expression and = Predicates.combineAnd(Arrays.asList(eq, lt, gt, neq));
+        Expression and = Predicates.combineAnd(asList(eq, lt, gt, neq));
         Expression exp = rule.rule(and);
         assertEquals(eq, exp);
     }
@@ -1202,7 +1207,7 @@ public class OptimizerRulesTests extends ESTestCase {
         NotEquals neq = notEqualsOf(fa, FOUR);
 
         PropagateEquals rule = new PropagateEquals();
-        Expression and = Predicates.combineAnd(Arrays.asList(eq, range, gt, neq));
+        Expression and = Predicates.combineAnd(asList(eq, range, gt, neq));
         Expression exp = rule.rule(and);
         assertEquals(eq, exp);
     }
@@ -1331,7 +1336,7 @@ public class OptimizerRulesTests extends ESTestCase {
         NotEquals neq = notEqualsOf(fa, TWO);
 
         PropagateEquals rule = new PropagateEquals();
-        Expression exp = rule.rule(Predicates.combineOr(Arrays.asList(eq, range, neq, gt)));
+        Expression exp = rule.rule(Predicates.combineOr(asList(eq, range, neq, gt)));
         assertEquals(TRUE, exp);
     }
 
@@ -1339,11 +1344,11 @@ public class OptimizerRulesTests extends ESTestCase {
     // Like / Regex
     //
     public void testMatchAllLikeToExist() throws Exception {
-        for (String s : Arrays.asList("%", "%%", "%%%")) {
+        for (String s : asList("%", "%%", "%%%")) {
             LikePattern pattern = new LikePattern(s, (char) 0);
             FieldAttribute fa = getFieldAttribute();
             Like l = new Like(EMPTY, fa, pattern);
-            Expression e = new OptimizerRules.ReplaceRegexMatch().rule(l);
+            Expression e = new ReplaceRegexMatch().rule(l);
             assertEquals(IsNotNull.class, e.getClass());
             IsNotNull inn = (IsNotNull) e;
             assertEquals(fa, inn.field());
@@ -1354,18 +1359,18 @@ public class OptimizerRulesTests extends ESTestCase {
             RLikePattern pattern = new RLikePattern(".*");
             FieldAttribute fa = getFieldAttribute();
             RLike l = new RLike(EMPTY, fa, pattern);
-            Expression e = new OptimizerRules.ReplaceRegexMatch().rule(l);
+            Expression e = new ReplaceRegexMatch().rule(l);
             assertEquals(IsNotNull.class, e.getClass());
             IsNotNull inn = (IsNotNull) e;
             assertEquals(fa, inn.field());
     }
 
     public void testExactMatchLike() throws Exception {
-        for (String s : Arrays.asList("ab", "ab0%", "ab0_c")) {
+        for (String s : asList("ab", "ab0%", "ab0_c")) {
             LikePattern pattern = new LikePattern(s, '0');
             FieldAttribute fa = getFieldAttribute();
             Like l = new Like(EMPTY, fa, pattern);
-            Expression e = new OptimizerRules.ReplaceRegexMatch().rule(l);
+            Expression e = new ReplaceRegexMatch().rule(l);
             assertEquals(Equals.class, e.getClass());
             Equals eq = (Equals) e;
             assertEquals(fa, eq.left());
@@ -1377,10 +1382,128 @@ public class OptimizerRulesTests extends ESTestCase {
         RLikePattern pattern = new RLikePattern("abc");
         FieldAttribute fa = getFieldAttribute();
         RLike l = new RLike(EMPTY, fa, pattern);
-        Expression e = new OptimizerRules.ReplaceRegexMatch().rule(l);
+        Expression e = new ReplaceRegexMatch().rule(l);
         assertEquals(Equals.class, e.getClass());
         Equals eq = (Equals) e;
         assertEquals(fa, eq.left());
         assertEquals("abc", eq.right().fold());
     }
+
+    //
+    // CombineDisjunction in Equals
+    //
+    public void testTwoEqualsWithOr() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+        Literal one = of(1);
+        Literal two = of(2);
+
+        Or or = new Or(EMPTY, equalsOf(fa, one), equalsOf(fa, two));
+        Expression e = new CombineDisjunctionsToIn().rule(or);
+        assertEquals(In.class, e.getClass());
+        In in = (In) e;
+        assertEquals(fa, in.value());
+        assertThat(in.list(), contains(one, two));
+    }
+
+    public void testTwoEqualsWithSameValue() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+        Literal one = of(1);
+
+        Or or = new Or(EMPTY, equalsOf(fa, one), equalsOf(fa, one));
+        Expression e = new CombineDisjunctionsToIn().rule(or);
+        assertEquals(Equals.class, e.getClass());
+        Equals eq = (Equals) e;
+        assertEquals(fa, eq.left());
+        assertEquals(one, eq.right());
+    }
+
+    public void testOneEqualsOneIn() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+        Literal one = of(1);
+        Literal two = of(2);
+
+        Or or = new Or(EMPTY, equalsOf(fa, one), new In(EMPTY, fa, singletonList(two)));
+        Expression e = new CombineDisjunctionsToIn().rule(or);
+        assertEquals(In.class, e.getClass());
+        In in = (In) e;
+        assertEquals(fa, in.value());
+        assertThat(in.list(), contains(one, two));
+    }
+
+    public void testOneEqualsOneInWithSameValue() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+        Literal one = of(1);
+        Literal two = of(2);
+
+        Or or = new Or(EMPTY, equalsOf(fa, one), new In(EMPTY, fa, asList(one, two)));
+        Expression e = new CombineDisjunctionsToIn().rule(or);
+        assertEquals(In.class, e.getClass());
+        In in = (In) e;
+        assertEquals(fa, in.value());
+        assertThat(in.list(), contains(one, two));
+    }
+
+    public void testSingleValueInToEquals() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+        Literal one = of(1);
+
+        Equals equals = equalsOf(fa, one);
+        Or or = new Or(EMPTY, equals, new In(EMPTY, fa, singletonList(one)));
+        Expression e = new CombineDisjunctionsToIn().rule(or);
+        assertEquals(equals, e);
+    }
+
+    public void testEqualsBehindAnd() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+        Literal one = of(1);
+        Literal two = of(2);
+
+        And and = new And(EMPTY, equalsOf(fa, one), equalsOf(fa, two));
+        Expression e = new CombineDisjunctionsToIn().rule(and);
+        assertEquals(and, e);
+    }
+
+    public void testTwoEqualsDifferentFields() throws Exception {
+        FieldAttribute fieldOne = getFieldAttribute("one");
+        FieldAttribute fieldTwo = getFieldAttribute("two");
+        Literal one = of(1);
+        Literal two = of(2);
+
+        Or or = new Or(EMPTY, equalsOf(fieldOne, one), equalsOf(fieldTwo, two));
+        Expression e = new CombineDisjunctionsToIn().rule(or);
+        assertEquals(or, e);
+    }
+
+    public void testMultipleIn() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+        Literal one = of(1);
+        Literal two = of(2);
+        Literal three = of(3);
+
+        Or firstOr = new Or(EMPTY, new In(EMPTY, fa, singletonList(one)), new In(EMPTY, fa, singletonList(two)));
+        Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, singletonList(three)));
+        Expression e = new CombineDisjunctionsToIn().rule(secondOr);
+        assertEquals(In.class, e.getClass());
+        In in = (In) e;
+        assertEquals(fa, in.value());
+        assertThat(in.list(), contains(one, two, three));
+    }
+
+    public void testOrWithNonCombinableExpressions() throws Exception {
+        FieldAttribute fa = getFieldAttribute();
+        Literal one = of(1);
+        Literal two = of(2);
+        Literal three = of(3);
+
+        Or firstOr = new Or(EMPTY, new In(EMPTY, fa, singletonList(one)), lessThanOf(fa, two));
+        Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, singletonList(three)));
+        Expression e = new CombineDisjunctionsToIn().rule(secondOr);
+        assertEquals(Or.class, e.getClass());
+        Or or = (Or) e;
+        assertEquals(or.left(), firstOr.right());
+        assertEquals(In.class, or.right().getClass());
+        In in = (In) or.right();
+        assertEquals(fa, in.value());
+        assertThat(in.list(), contains(one, three));
+    }
 }

+ 11 - 1
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java

@@ -32,7 +32,6 @@ import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.LessT
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.LessThanOrEqual;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.NotEquals;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.NullEquals;
-import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.ReplaceRegexMatch;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BooleanLiteralsOnTheRight;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BooleanSimplification;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.CombineBinaryComparisons;
@@ -41,6 +40,7 @@ import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.OptimizerExpressionRu
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.OptimizerRule;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PropagateEquals;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PruneLiteralsInOrderBy;
+import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.ReplaceRegexMatch;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SetAsOptimized;
 import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.TransformDirection;
 import org.elasticsearch.xpack.ql.plan.logical.Aggregate;
@@ -88,6 +88,7 @@ import org.elasticsearch.xpack.sql.plan.logical.SubQueryAlias;
 import org.elasticsearch.xpack.sql.session.EmptyExecutable;
 import org.elasticsearch.xpack.sql.session.SingletonExecutable;
 
+import java.time.ZoneId;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
@@ -143,6 +144,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
                 // needs to occur before BinaryComparison combinations (see class)
                 new PropagateEquals(),
                 new CombineBinaryComparisons(),
+                new CombineDisjunctionsToIn(),
                 // prune/elimination
                 new PruneLiteralsInGroupBy(),
                 new PruneDuplicatesInGroupBy(),
@@ -236,6 +238,14 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
         }
     }
 
+    static class CombineDisjunctionsToIn extends org.elasticsearch.xpack.ql.optimizer.OptimizerRules.CombineDisjunctionsToIn {
+
+        @Override
+        protected In createIn(Expression key, List<Expression> values, ZoneId zoneId) {
+            return new In(key.source(), key, values, zoneId);
+        }
+    }
+
     static class PruneLiteralsInGroupBy extends OptimizerRule<Aggregate> {
 
         @Override

+ 8 - 0
x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java

@@ -2338,4 +2338,12 @@ public class QueryTranslatorTests extends ESTestCase {
         PhysicalPlan p = testContext.optimizeAndPlan("SELECT long FROM test WHERE long IN (1, 2, 3, " + Long.MAX_VALUE + ", 5, 6, 7)");
         assertEquals(EsQueryExec.class, p.getClass());
     }
+
+    public void testEqualsAndInOnTheSameField() {
+        PhysicalPlan physicalPlan = optimizeAndPlan("SELECT int FROM test WHERE int in (1, 2) OR int = 3 OR int = 2");
+        assertEquals(EsQueryExec.class, physicalPlan.getClass());
+        EsQueryExec eqe = (EsQueryExec) physicalPlan;
+        assertEquals(1, eqe.output().size());
+        assertThat(eqe.queryContainer().toString().replaceAll("\\s+", ""), containsString("\"terms\":{\"int\":[1,2,3],"));
+    }
 }