Browse Source

ESQL: extract common filter from aggs (#115678) (#116833)

This adds a new optimiser rule to extract the filters from aggs, if the
same one is provided for all of them, pushing it under the agg. This
allows for combining the filter further or pushing it down to source.
Example:

```
 ... | STATS MIN(a) WHERE b > 0, MIN(c) WHERE b > 0 | ...
=>
 ... | WHERE b > 0 | STATS MIN(a), MIN(c) | ...
```

Related: #114352.
Bogdan Pintea 11 months ago
parent
commit
343dcc0de4

+ 5 - 0
docs/changelog/115678.yaml

@@ -0,0 +1,5 @@
+pr: 115678
+summary: "ESQL: extract common filter from aggs"
+area: ES|QL
+type: enhancement
+issues: []

+ 41 - 0
x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/Predicates.java

@@ -6,7 +6,9 @@
  */
 package org.elasticsearch.xpack.esql.core.expression.predicate;
 
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
 import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And;
 import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or;
 
@@ -113,4 +115,43 @@ public abstract class Predicates {
         }
         return diff.isEmpty() ? emptyList() : diff;
     }
+
+    /**
+     * Given a list of expressions of predicates, extract a new expression of
+     * all the common ones and return it, along the original list with the
+     * common ones removed.
+     * <p>
+     * Example: for ['field1 > 0 AND field2 > 0', 'field1 > 0 AND field3 > 0',
+     * 'field1 > 0'], the function will return 'field1 > 0' as the common
+     * predicate expression and ['field2 > 0', 'field3 > 0', Literal.TRUE] as
+     * the left predicates list.
+     *
+     * @param expressions list of expressions to extract common predicates from.
+     * @return a tuple having as the first element an expression of the common
+     * predicates and as the second element the list of expressions with the
+     * common predicates removed. If there are no common predicates, `null` will
+     * be returned as the first element and the original list as the second. If
+     * for one of the expressions in the input list, nothing is left after
+     * trimming the common predicates, it will be replaced with Literal.TRUE.
+     */
+    public static Tuple<Expression, List<Expression>> extractCommon(List<Expression> expressions) {
+        List<Expression> common = null;
+        List<List<Expression>> splitAnds = new ArrayList<>(expressions.size());
+        for (var expression : expressions) {
+            var split = splitAnd(expression);
+            common = common == null ? split : inCommon(split, common);
+            if (common.isEmpty()) {
+                return Tuple.tuple(null, expressions);
+            }
+            splitAnds.add(split);
+        }
+
+        List<Expression> trimmed = new ArrayList<>(expressions.size());
+        final List<Expression> finalCommon = common;
+        splitAnds.forEach(split -> {
+            var subtracted = subtract(split, finalCommon);
+            trimmed.add(subtracted.isEmpty() ? Literal.TRUE : combineAnd(subtracted));
+        });
+        return Tuple.tuple(combineAnd(common), trimmed);
+    }
 }

+ 51 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

@@ -2641,6 +2641,57 @@ c2:l |c2_f:l |m2:i |m2_f:i |c:l
 1    |1      |5    |5      |21
 ;
 
+commonFilterExtractionWithAliasing
+required_capability: per_agg_filtering
+from employees
+| eval eno = emp_no
+| drop emp_no
+| stats min_sal = min(salary) where eno <= 10010,
+        min_hei = min(height) where eno <= 10010
+;
+
+min_sal:integer |min_hei:double
+36174           |1.56
+;
+
+commonFilterExtractionWithAliasAndOriginal
+required_capability: per_agg_filtering
+from employees
+| eval eno = emp_no
+| stats min_sal = min(salary) where eno <= 10010,
+        min_hei = min(height) where emp_no <= 10010
+;
+
+// same results as above in commonFilterExtractionWithAliasing
+min_sal:integer |min_hei:double
+36174           |1.56
+;
+
+commonFilterExtractionWithAliasAndOriginalNeedingNormalization
+required_capability: per_agg_filtering
+from employees
+| eval eno = emp_no
+| stats min_sal = min(salary) where eno <= 10010,
+        min_hei = min(height) where emp_no <= 10010,
+        max_hei = max(height) where 10010 >= emp_no
+;
+
+min_sal:integer |min_hei:double |max_hei:double
+36174           |1.56           |2.1
+;
+
+commonFilterExtractionWithAliasAndOriginalNeedingNormalizationAndSimplification
+required_capability: per_agg_filtering
+from employees
+| eval eno = emp_no
+| stats min_sal = min(salary) where eno <= 10010,
+        min_hei = min(height) where not (emp_no > 10010),
+        max_hei = max(height) where 10010 >= emp_no
+;
+
+min_sal:integer |min_hei:double |max_hei:double
+36174           |1.56           |2.1
+;
 
 statsByConstantExpressionNoAggs
 required_capability: fix_stats_by_foldable_expression

+ 6 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java

@@ -19,6 +19,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.CombineEvals;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.CombineProjections;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.ConstantFolding;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.ConvertStringToByteRef;
+import org.elasticsearch.xpack.esql.optimizer.rules.logical.ExtractAggregateCommonFilter;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.FoldNull;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.LiteralsOnTheRight;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.PartiallyFoldCase;
@@ -124,8 +125,9 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
             "Substitutions",
             Limiter.ONCE,
             new SubstituteSurrogatePlans(),
-            // translate filtered expressions into aggregate with filters - can't use surrogate expressions because it was
-            // retrofitted for constant folding - this needs to be fixed
+            // Translate filtered expressions into aggregate with filters - can't use surrogate expressions because it was
+            // retrofitted for constant folding - this needs to be fixed.
+            // Needs to occur before ReplaceAggregateAggExpressionWithEval, which will update the functions, losing the filter.
             new SubstituteFilteredExpression(),
             new RemoveStatsOverride(),
             // first extract nested expressions inside aggs
@@ -170,8 +172,10 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
             new BooleanFunctionEqualsElimination(),
             new CombineBinaryComparisons(),
             new CombineDisjunctions(),
+            // TODO: bifunction can now (since we now have just one data types set) be pushed into the rule
             new SimplifyComparisonsArithmetics(DataType::areCompatible),
             new ReplaceStatsFilteredAggWithEval(),
+            new ExtractAggregateCommonFilter(),
             // prune/elimination
             new PruneFilters(),
             new PruneColumns(),

+ 78 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ExtractAggregateCommonFilter.java

@@ -0,0 +1,78 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.optimizer.rules.logical;
+
+import org.elasticsearch.xpack.esql.core.expression.Alias;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
+import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
+import org.elasticsearch.xpack.esql.plan.logical.Filter;
+import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.elasticsearch.xpack.esql.core.expression.predicate.Predicates.extractCommon;
+
+/**
+ * Extract a per-function expression filter applied to all the aggs as a query {@link Filter}, when no groups are provided.
+ * <p>
+ *     Example:
+ *     <pre>
+ *         ... | STATS MIN(a) WHERE b > 0, MIN(c) WHERE b > 0 | ...
+ *         =>
+ *         ... | WHERE b > 0 | STATS MIN(a), MIN(c) | ...
+ *     </pre>
+ */
+public final class ExtractAggregateCommonFilter extends OptimizerRules.OptimizerRule<Aggregate> {
+    public ExtractAggregateCommonFilter() {
+        super(OptimizerRules.TransformDirection.UP);
+    }
+
+    @Override
+    protected LogicalPlan rule(Aggregate aggregate) {
+        if (aggregate.groupings().isEmpty() == false) {
+            return aggregate; // no optimization for grouped stats
+        }
+
+        // collect all filters from the agg functions
+        List<Expression> filters = new ArrayList<>(aggregate.aggregates().size());
+        for (NamedExpression ne : aggregate.aggregates()) {
+            if (ne instanceof Alias alias && alias.child() instanceof AggregateFunction aggFunction && aggFunction.hasFilter()) {
+                filters.add(aggFunction.filter());
+            } else {
+                return aggregate; // (at least one) agg function has no filter -- skip optimization
+            }
+        }
+
+        // extract common filters
+        var common = extractCommon(filters);
+        if (common.v1() == null) { // no common filter
+            return aggregate;
+        }
+
+        // replace agg functions' filters with trimmed ones
+        var newFilters = common.v2();
+        List<NamedExpression> newAggs = new ArrayList<>(aggregate.aggregates().size());
+        for (int i = 0; i < aggregate.aggregates().size(); i++) {
+            var alias = (Alias) aggregate.aggregates().get(i);
+            var newChild = ((AggregateFunction) alias.child()).withFilter(newFilters.get(i));
+            newAggs.add(alias.replaceChild(newChild));
+        }
+
+        // build the new agg on top of extracted filter
+        return new Aggregate(
+            aggregate.source(),
+            new Filter(aggregate.source(), aggregate.child(), common.v1()),
+            aggregate.aggregateType(),
+            aggregate.groupings(),
+            newAggs
+        );
+    }
+}

+ 259 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

@@ -840,6 +840,265 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
         var source = as(aggregate.child(), EsRelation.class);
     }
 
+    public void testExtractStatsCommonFilter() {
+        var plan = plan("""
+            from test
+            | stats m = min(salary) where emp_no > 1,
+                    max(salary) where emp_no > 1
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        assertThat(agg.aggregates().size(), is(2));
+
+        var alias = as(agg.aggregates().get(0), Alias.class);
+        var aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), is(Literal.TRUE));
+
+        alias = as(agg.aggregates().get(1), Alias.class);
+        aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), is(Literal.TRUE));
+
+        var filter = as(agg.child(), Filter.class);
+        assertThat(Expressions.name(filter.condition()), is("emp_no > 1"));
+
+        var source = as(filter.child(), EsRelation.class);
+    }
+
+    public void testExtractStatsCommonFilterUsingAliases() {
+        var plan = plan("""
+            from test
+            | eval eno = emp_no
+            | drop emp_no
+            | stats min(salary) where eno > 1,
+                    max(salary) where eno > 1
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        assertThat(agg.aggregates().size(), is(2));
+
+        var alias = as(agg.aggregates().get(0), Alias.class);
+        var aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), is(Literal.TRUE));
+
+        alias = as(agg.aggregates().get(1), Alias.class);
+        aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), is(Literal.TRUE));
+
+        var filter = as(agg.child(), Filter.class);
+        assertThat(Expressions.name(filter.condition()), is("eno > 1"));
+
+        var source = as(filter.child(), EsRelation.class);
+    }
+
+    public void testExtractStatsCommonFilterUsingJustOneAlias() {
+        var plan = plan("""
+            from test
+            | eval eno = emp_no
+            | stats min(salary) where emp_no > 1,
+                    max(salary) where eno > 1
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        assertThat(agg.aggregates().size(), is(2));
+
+        var alias = as(agg.aggregates().get(0), Alias.class);
+        var aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), is(Literal.TRUE));
+
+        alias = as(agg.aggregates().get(1), Alias.class);
+        aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), is(Literal.TRUE));
+
+        var filter = as(agg.child(), Filter.class);
+        var gt = as(filter.condition(), GreaterThan.class);
+        assertThat(Expressions.name(gt.left()), is("emp_no"));
+        assertTrue(gt.right().foldable());
+        assertThat(gt.right().fold(), is(1));
+
+        var source = as(filter.child(), EsRelation.class);
+    }
+
+    public void testExtractStatsCommonFilterSkippedNotSameFilter() {
+        var plan = plan("""
+            from test
+            | stats min(salary) where emp_no > 1,
+                    max(salary) where emp_no > 2
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        assertThat(agg.aggregates().size(), is(2));
+
+        var alias = as(agg.aggregates().get(0), Alias.class);
+        var aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), instanceOf(BinaryComparison.class));
+
+        alias = as(agg.aggregates().get(1), Alias.class);
+        aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), instanceOf(BinaryComparison.class));
+
+        var source = as(agg.child(), EsRelation.class);
+    }
+
+    public void testExtractStatsCommonFilterSkippedOnLackingFilter() {
+        var plan = plan("""
+            from test
+            | stats min(salary),
+                    max(salary) where emp_no > 2
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        assertThat(agg.aggregates().size(), is(2));
+
+        var alias = as(agg.aggregates().get(0), Alias.class);
+        var aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), is(Literal.TRUE));
+
+        alias = as(agg.aggregates().get(1), Alias.class);
+        aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), instanceOf(BinaryComparison.class));
+
+        var source = as(agg.child(), EsRelation.class);
+    }
+
+    public void testExtractStatsCommonFilterSkippedWithGroups() {
+        var plan = plan("""
+            from test
+            | stats min(salary) where emp_no > 2,
+                    max(salary) where emp_no > 2 by first_name
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        assertThat(agg.aggregates().size(), is(3));
+
+        var alias = as(agg.aggregates().get(0), Alias.class);
+        var aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), instanceOf(BinaryComparison.class));
+
+        alias = as(agg.aggregates().get(1), Alias.class);
+        aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), instanceOf(BinaryComparison.class));
+
+        var source = as(agg.child(), EsRelation.class);
+    }
+
+    public void testExtractStatsCommonFilterNormalizeAndCombineWithExistingFilter() {
+        var plan = plan("""
+            from test
+            | where emp_no > 3
+            | stats min(salary) where emp_no > 2,
+                    max(salary) where 2 < emp_no
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        assertThat(agg.aggregates().size(), is(2));
+
+        var alias = as(agg.aggregates().get(0), Alias.class);
+        var aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), is(Literal.TRUE));
+
+        alias = as(agg.aggregates().get(1), Alias.class);
+        aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), is(Literal.TRUE));
+
+        var filter = as(agg.child(), Filter.class);
+        assertThat(Expressions.name(filter.condition()), is("emp_no > 3"));
+
+        var source = as(filter.child(), EsRelation.class);
+    }
+
+    public void testExtractStatsCommonFilterInConjunction() {
+        var plan = plan("""
+            from test
+            | stats min(salary) where emp_no > 2 and first_name == "John",
+                    max(salary) where emp_no > 1 + 1 and length(last_name) < 19
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        assertThat(agg.aggregates().size(), is(2));
+
+        var alias = as(agg.aggregates().get(0), Alias.class);
+        var aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(Expressions.name(aggFunc.filter()), is("first_name == \"John\""));
+
+        alias = as(agg.aggregates().get(1), Alias.class);
+        aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(Expressions.name(aggFunc.filter()), is("length(last_name) < 19"));
+
+        var filter = as(agg.child(), Filter.class);
+        var gt = as(filter.condition(), GreaterThan.class); // name is "emp_no > 1 + 1"
+        assertThat(Expressions.name(gt.left()), is("emp_no"));
+        assertTrue(gt.right().foldable());
+        assertThat(gt.right().fold(), is(2));
+
+        var source = as(filter.child(), EsRelation.class);
+    }
+
+    public void testExtractStatsCommonFilterInConjunctionWithMultipleCommonConjunctions() {
+        var plan = plan("""
+            from test
+            | stats min(salary) where emp_no < 10 and first_name == "John" and last_name == "Doe",
+                    max(salary) where emp_no - 1 < 2 + 7 and length(last_name) < 19 and last_name == "Doe"
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        assertThat(agg.aggregates().size(), is(2));
+
+        var alias = as(agg.aggregates().get(0), Alias.class);
+        var aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(Expressions.name(aggFunc.filter()), is("first_name == \"John\""));
+
+        alias = as(agg.aggregates().get(1), Alias.class);
+        aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(Expressions.name(aggFunc.filter()), is("length(last_name) < 19"));
+
+        var filter = as(agg.child(), Filter.class);
+        var and = as(filter.condition(), And.class);
+
+        var lt = as(and.left(), LessThan.class);
+        assertThat(Expressions.name(lt.left()), is("emp_no"));
+        assertTrue(lt.right().foldable());
+        assertThat(lt.right().fold(), is(10));
+
+        var equals = as(and.right(), Equals.class);
+        assertThat(Expressions.name(equals.left()), is("last_name"));
+        assertTrue(equals.right().foldable());
+        assertThat(equals.right().fold(), is(BytesRefs.toBytesRef("Doe")));
+
+        var source = as(filter.child(), EsRelation.class);
+    }
+
+    public void testExtractStatsCommonFilterSkippedDueToDisjunction() {
+        // same query as in testExtractStatsCommonFilterInConjunction, except for the OR in the filter
+        var plan = plan("""
+            from test
+            | stats min(salary) where emp_no > 2 OR first_name == "John",
+                    max(salary) where emp_no > 1 + 1 and length(last_name) < 19
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        assertThat(agg.aggregates().size(), is(2));
+
+        var alias = as(agg.aggregates().get(0), Alias.class);
+        var aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), instanceOf(Or.class));
+
+        alias = as(agg.aggregates().get(1), Alias.class);
+        aggFunc = as(alias.child(), AggregateFunction.class);
+        assertThat(aggFunc.filter(), instanceOf(And.class));
+
+        var source = as(agg.child(), EsRelation.class);
+    }
+
     public void testQlComparisonOptimizationsApply() {
         var plan = plan("""
             from test