Parcourir la source

[9.1] Propagates filter() to aggregation functions' surrogates (#134461)

---------

Co-authored-by: Jan Kuipers <jan.kuipers@elastic.co>
Co-authored-by: Jan Kuipers <148754765+jan-elastic@users.noreply.github.com>
Nacho Cordón il y a 2 semaines
Parent
commit
a41045dbb1

+ 6 - 0
docs/changelog/134461.yaml

@@ -0,0 +1,6 @@
+pr: 134461
+summary: Propagates filter() to aggregation functions' surrogates
+area: Aggregations
+type: bug
+issues:
+ - 134380

+ 179 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

@@ -3122,3 +3122,182 @@ FROM employees
 m:datetime               | x:integer | d:boolean
 1999-04-30T00:00:00.000Z | 2         | true
 ;
+
+sumWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+required_capability: aggregate_metric_double_convert_to
+
+FROM employees
+| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(1)
+| STATS sum1 = SUM(1),
+        sum2 = SUM(1) WHERE emp_no == 10080,
+        sum3 = SUM(1) WHERE emp_no < 10080,
+        sum4 = SUM(1) WHERE emp_no >= 10080,
+        sum5 = SUM(agg_metric),
+        sum6 = SUM(agg_metric) WHERE emp_no == 10080
+;
+
+sum1:long | sum2:long | sum3:long | sum4:long | sum5:double | sum6:double
+100       | 1         | 79        | 21        | 100.0       | 1.0
+;
+
+weightedAvgWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+| MV_EXPAND x
+| STATS w_avg1 = WEIGHTED_AVG(x, 1) WHERE x == 5,
+        w_avg2 = WEIGHTED_AVG(x, x) WHERE x == 5,
+        w_avg3 = WEIGHTED_AVG(x, 2) WHERE x <= 5,
+        w_avg4 = WEIGHTED_AVG(x, x) WHERE x > 5,
+        w_avg5 = WEIGHTED_AVG([1,2,3], 1),
+        w_avg6 = WEIGHTED_AVG([1,2,3], 1) WHERE x == 5
+;
+
+w_avg1:double | w_avg2:double | w_avg3:double | w_avg4:double | w_avg5:double | w_avg6:double
+5.0           | 5.0           | 3.0           | 8.25          | 2.0           | 2.0
+;
+
+maxWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+required_capability: aggregate_metric_double_convert_to
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
+| STATS max1 = MAX(agg_metric) WHERE x <= 3,
+        max2 = MAX(agg_metric),
+        max3 = MAX(x),
+        max4 = MAX(x) WHERE x > 3
+;
+
+max1:double | max2:double | max3:integer | max4:integer
+3.0         | 5.0         | 5            | 5
+;
+
+minWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+required_capability: aggregate_metric_double_convert_to
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
+| STATS min1 = MIN(agg_metric) WHERE x <= 3,
+        min2 = MIN(agg_metric),
+        min3 = MIN(x),
+        min4 = MIN(x) WHERE x > 3
+;
+
+min1:double | min2:double | min3:integer | min4:integer
+1.0         | 1.0         | 1            | 4
+;
+
+countWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+required_capability: aggregate_metric_double_convert_to
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
+| STATS count1 = COUNT(x) WHERE x >= 3,
+        count2 = COUNT(x),
+        count3 = COUNT(agg_metric),
+        count4 = COUNT(agg_metric) WHERE x >=3,
+        count5 = COUNT(4) WHERE x >= 3,
+        count6 = COUNT(*) WHERE x >= 3,
+        count7 = COUNT([1,2,3]) WHERE x >= 3,
+        count8 = COUNT([1,2,3])
+;
+
+count1:long | count2:long | count3:long | count4:long | count5:long | count6:long | count7:long | count8:long
+3           | 5           | 5           | 3           | 3           | 3           | 9           | 15
+;
+
+countDistinctWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| STATS count1 = COUNT_DISTINCT(x) WHERE x <= 3,
+        count2 = COUNT_DISTINCT(x),
+        count3 = COUNT_DISTINCT(1) WHERE x <= 3,
+        count4 = COUNT_DISTINCT(1)
+;
+
+count1:long | count2:long | count3:long | count4:long
+3           | 5           | 1           | 1
+;
+
+avgWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+required_capability: aggregate_metric_double_convert_to
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
+| STATS avg1 = AVG(x) WHERE x <= 3,
+        avg2 = AVG(x),
+        avg3 = AVG(agg_metric) WHERE x <=3,
+        avg4 = AVG(agg_metric)
+;
+
+avg1:double | avg2:double | avg3:double | avg4:double
+2.0         | 3.0         | 2.0         | 3.0
+;
+
+percentileWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| STATS percentile1 = PERCENTILE(x, 50) WHERE x <= 3,
+        percentile2 = PERCENTILE(x, 50)
+;
+
+percentile1:double | percentile2:double
+2.0                | 3.0
+;
+
+medianWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| STATS median1 = MEDIAN(x) WHERE x <= 3,
+        median2 = MEDIAN(x),
+        median3 = MEDIAN([5,6,7,8,9]) WHERE x <= 3,
+        median4 = MEDIAN([5,6,7,8,9])
+;
+
+median1:double | median2:double | median3:double | median4:double
+2.0            | 3.0            | 7.0            | 7.0
+;
+
+medianAbsoluteDeviationWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 3, 4, 7, 11, 18]
+| MV_EXPAND x
+| STATS median_dev1 = MEDIAN_ABSOLUTE_DEVIATION(x) WHERE x <= 3,
+        median_dev2 = MEDIAN_ABSOLUTE_DEVIATION(x),
+        median_dev3 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25]) WHERE x <= 3,
+        median_dev4 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25])
+;
+
+median_dev1:double | median_dev2:double | median_dev3:double | median_dev4:double
+1.0                | 3.5                | 5.5                | 5.5
+;
+
+topWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+FROM employees
+| STATS min1 = TOP(emp_no, 1, "ASC") WHERE emp_no > 10010,
+        min2 = TOP(emp_no, 2, "ASC") WHERE emp_no > 10010,
+        max1 = TOP(emp_no, 1, "DESC") WHERE emp_no < 10080,
+        max2 = TOP(emp_no, 2, "DESC") WHERE emp_no < 10080
+;
+
+min1:integer | min2:integer   | max1:integer | max2:integer
+10011        | [10011, 10012] | 10079        | [10079, 10078]
+;

+ 9 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

@@ -1267,7 +1267,15 @@ public class EsqlCapabilities {
         /**
          * Support correct counting of skipped shards.
          */
-        CORRECT_SKIPPED_SHARDS_COUNT;
+        CORRECT_SKIPPED_SHARDS_COUNT,
+
+        /**
+         * Bugfix for STATS {{expression}} WHERE {{condition}} when the
+         * expression is replaced by something else on planning
+         * e.g. STATS SUM(1) WHERE x==3 is replaced by
+         *      STATS MV_SUM(const)*COUNT(*) WHERE x == 3.
+         */
+        STATS_WITH_FILTERED_SURROGATE_FIXED;
 
         private final boolean enabled;
 

+ 6 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java

@@ -146,7 +146,11 @@ public class Count extends AggregateFunction implements ToAggregator, SurrogateE
         var s = source();
         var field = field();
         if (field.dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
-            return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT));
+            return new Sum(
+                s,
+                FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT),
+                filter()
+            );
         }
 
         if (field.foldable()) {
@@ -163,7 +167,7 @@ public class Count extends AggregateFunction implements ToAggregator, SurrogateE
             return new Mul(
                 s,
                 new Coalesce(s, new MvCount(s, field), List.of(new Literal(s, 0, DataType.INTEGER))),
-                new Count(s, Literal.keyword(s, StringUtils.WILDCARD))
+                new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())
             );
         }
 

+ 5 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java

@@ -153,7 +153,11 @@ public class Max extends AggregateFunction implements ToAggregator, SurrogateExp
     @Override
     public Expression surrogate() {
         if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
-            return new Max(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX));
+            return new Max(
+                source(),
+                FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX),
+                filter()
+            );
         }
         return field().foldable() ? new MvMax(source(), field()) : null;
     }

+ 1 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java

@@ -117,6 +117,6 @@ public class Median extends AggregateFunction implements SurrogateExpression {
 
         return field.foldable()
             ? new MvMedian(s, new ToDouble(s, field))
-            : new Percentile(source(), field(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
+            : new Percentile(source(), field(), filter(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
     }
 }

+ 5 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java

@@ -153,7 +153,11 @@ public class Min extends AggregateFunction implements ToAggregator, SurrogateExp
     @Override
     public Expression surrogate() {
         if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
-            return new Min(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN));
+            return new Min(
+                source(),
+                FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN),
+                filter()
+            );
         }
         return field().foldable() ? new MvMin(source(), field()) : null;
     }

+ 6 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java

@@ -139,10 +139,14 @@ public class Sum extends NumericAggregate implements SurrogateExpression {
         var s = source();
         var field = field();
         if (field.dataType() == AGGREGATE_METRIC_DOUBLE) {
-            return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.SUM));
+            return new Sum(
+                s,
+                FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.SUM),
+                filter()
+            );
         }
 
         // SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
-        return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD))) : null;
+        return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())) : null;
     }
 }

+ 2 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java

@@ -218,9 +218,9 @@ public class Top extends AggregateFunction implements ToAggregator, SurrogateExp
 
         if (limitValue() == 1) {
             if (orderValue()) {
-                return new Min(s, field());
+                return new Min(s, field(), filter());
             } else {
-                return new Max(s, field());
+                return new Max(s, field(), filter());
             }
         }
 

+ 2 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java

@@ -160,9 +160,9 @@ public class WeightedAvg extends AggregateFunction implements SurrogateExpressio
             return new MvAvg(s, field);
         }
         if (weight.foldable()) {
-            return new Div(s, new Sum(s, field), new Count(s, field), dataType());
+            return new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
         } else {
-            return new Div(s, new Sum(s, new Mul(s, field, weight)), new Sum(s, weight), dataType());
+            return new Div(s, new Sum(s, new Mul(s, field, weight), filter()), new Sum(s, weight, filter()), dataType());
         }
     }
 

+ 23 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java

@@ -162,6 +162,29 @@ public abstract class AbstractAggregationTestCase extends AbstractFunctionTestCa
         }, this::evaluate);
     }
 
+    public void testSurrogateHasFilter() {
+        Expression expression = randomFrom(
+            buildLiteralExpression(testCase),
+            buildDeepCopyOfFieldExpression(testCase),
+            buildFieldExpression(testCase)
+        );
+
+        assumeTrue("expression should have no type errors", expression.typeResolved().resolved());
+
+        if (expression instanceof AggregateFunction && expression instanceof SurrogateExpression) {
+            var filter = ((AggregateFunction) expression).filter();
+
+            var surrogate = ((SurrogateExpression) expression).surrogate();
+
+            if (surrogate != null) {
+                surrogate.forEachDown(AggregateFunction.class, child -> {
+                    var surrogateFilter = child.filter();
+                    assertEquals(filter, surrogateFilter);
+                });
+            }
+        }
+    }
+
     private void aggregateSingleMode(Expression expression) {
         Object result;
         try (var aggregator = aggregator(expression, initialInputChannels(), AggregatorMode.SINGLE)) {