Browse Source

ESQL: Allow reusing BUCKET grouping expressions in aggs (#107578)

This fixes the queries reusing in STATS aggs part of expressions with
`BUCKET` declared in STATS grouping part.

Ex: `| STATS BUCKET(salary, 1000.) + 1 BY BUCKET(salary, 1000.)`

This was failing since the agg BUCKET's `salary` reference is no longer
available in the synthetic EVAL generated on top of the aggregation,
evaluating the "aggs" expression (the addition in the example above).
Bogdan Pintea 1 year ago
parent
commit
d15c59d423

+ 5 - 0
docs/changelog/107578.yaml

@@ -0,0 +1,5 @@
+pr: 107578
+summary: "ESQL: Allow reusing BUCKET grouping expressions in aggs"
+area: ES|QL
+type: bug
+issues: []

+ 68 - 1
x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec

@@ -415,13 +415,80 @@ FROM employees
 | LIMIT 4
 ;
 
- sumK:double   | b1k:double    | b2k:double      
+ sumK:double   | b1k:double    | b2k:double
 49.0           |25000.0        |24000.0
 52.0           |26000.0        |26000.0
 53.0           |27000.0        |26000.0
 56.0           |28000.0        |28000.0
 ;
 
+reuseGroupingFunction#[skip:-8.14.99, reason:BUCKET renamed in 8.14]
+FROM employees
+| STATS sum = 1 + BUCKET(salary, 1000.) BY b1k = BUCKET(salary, 1000.)
+| SORT sum
+| LIMIT 4
+;
+
+ sum:double    | b1k:double
+25001.0        |25000.0
+26001.0        |26000.0
+27001.0        |27000.0
+28001.0        |28000.0
+;
+
+reuseGroupingFunctionWithExpression#[skip:-8.14.99, reason:BUCKET renamed in 8.14]
+FROM employees
+| STATS sum = BUCKET(salary % 2 + 13, 1.) + 1 BY bucket = BUCKET(salary % 2 + 13, 1.)
+| SORT sum
+;
+
+ sum:double    | bucket:double
+14.0           |13.0
+15.0           |14.0
+;
+
+reuseGroupingFunctionWithinAggs#[skip:-8.14.99, reason:BUCKET renamed in 8.14]
+FROM employees
+| STATS sum = 1 + MAX(1 + BUCKET(salary, 1000.)) BY BUCKET(salary, 1000.) + 1
+| SORT sum
+| LIMIT 4
+;
+
+ sum:double    |BUCKET(salary, 1000.) + 1:double
+25002.0        |25001.0
+26002.0        |26001.0
+27002.0        |27001.0
+28002.0        |28001.0
+;
+
+reuseGroupingFunctionWithAggsExpression#[skip:-8.14.99, reason:BUCKET renamed in 8.14]
+FROM employees
+| STATS sum = 1 + AVG(BUCKET(salary, 1000.)) + BUCKET(salary, 1000.) BY bucket = BUCKET(salary, 1000.)
+| SORT sum
+| LIMIT 4
+;
+
+ sum:double    | bucket:double
+50001.0        |25000.0
+52001.0        |26000.0
+54001.0        |27000.0
+56001.0        |28000.0
+;
+
+reuseMultipleGroupingFunctionWithAggsExpression#[skip:-8.14.99, reason:BUCKET renamed in 8.14]
+FROM employees
+| STATS sum = b2k + AVG(BUCKET(salary, 1000.)) + BUCKET(salary, 1000.) BY b1k = BUCKET(salary, 1000.), b2k = BUCKET(salary, 2000.)
+| SORT sum
+| LIMIT 4
+;
+
+ sum:double    | b1k:double    | b2k:double
+74000.0        |25000.0        |24000.0
+78000.0        |26000.0        |26000.0
+80000.0        |27000.0        |26000.0
+84000.0        |28000.0        |28000.0
+;
+
 //
 // BIN copies
 //

+ 12 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java

@@ -18,6 +18,7 @@ import org.elasticsearch.xpack.esql.VerificationException;
 import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.Equals;
 import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
+import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
 import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
 import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
 import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesFunction;
@@ -1308,6 +1309,7 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
         protected LogicalPlan rule(Aggregate aggregate) {
             List<Alias> evals = new ArrayList<>();
             Map<String, Attribute> evalNames = new HashMap<>();
+            Map<GroupingFunction, Attribute> groupingAttributes = new HashMap<>();
             List<Expression> newGroupings = new ArrayList<>(aggregate.groupings());
             boolean groupingChanged = false;
 
@@ -1321,6 +1323,9 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
                     evals.add(as);
                     evalNames.put(as.name(), attr);
                     newGroupings.set(i, attr);
+                    if (as.child() instanceof GroupingFunction gf) {
+                        groupingAttributes.put(gf, attr);
+                    }
                 }
             }
 
@@ -1374,6 +1379,13 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
                         }
                         return result;
                     });
+                    // replace any grouping functions with their references pointing to the added synthetic eval
+                    replaced = replaced.transformDown(GroupingFunction.class, gf -> {
+                        aggsChanged.set(true);
+                        // should never return null, as it's verified.
+                        // but even if broken, the transform will fail safely; otoh, returning `gf` will fail later due to incorrect plan.
+                        return groupingAttributes.get(gf);
+                    });
 
                     return as.replaceChild(replaced);
                 });

+ 38 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

@@ -3348,6 +3348,44 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
         );
     }
 
+    /*
+     * Project[[bucket(salary, 1000.) + 1{r}#3, bucket(salary, 1000.){r}#5]]
+        \_Eval[[bucket(salary, 1000.){r}#5 + 1[INTEGER] AS bucket(salary, 1000.) + 1]]
+          \_Limit[1000[INTEGER]]
+            \_Aggregate[[bucket(salary, 1000.){r}#5],[bucket(salary, 1000.){r}#5]]
+              \_Eval[[BUCKET(salary{f}#12,1000.0[DOUBLE]) AS bucket(salary, 1000.)]]
+                \_EsRelation[test][_meta_field{f}#13, emp_no{f}#7, first_name{f}#8, ge..]
+     */
+    public void testBucketWithAggExpression() {
+        var plan = plan("""
+            from test
+            | stats bucket(salary, 1000.) + 1 by bucket(salary, 1000.)
+            """);
+        var project = as(plan, Project.class);
+        var evalTop = as(project.child(), Eval.class);
+        var limit = as(evalTop.child(), Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        var evalBottom = as(agg.child(), Eval.class);
+        var relation = as(evalBottom.child(), EsRelation.class);
+
+        assertThat(evalTop.fields().size(), is(1));
+        assertThat(evalTop.fields().get(0), instanceOf(Alias.class));
+        assertThat(evalTop.fields().get(0).child(), instanceOf(Add.class));
+        var add = (Add) evalTop.fields().get(0).child();
+        assertThat(add.left(), instanceOf(ReferenceAttribute.class));
+        var ref = (ReferenceAttribute) add.left();
+
+        assertThat(evalBottom.fields().size(), is(1));
+        assertThat(evalBottom.fields().get(0), instanceOf(Alias.class));
+        var alias = evalBottom.fields().get(0);
+        assertEquals(ref, alias.toAttribute());
+
+        assertThat(agg.aggregates().size(), is(1));
+        assertThat(agg.aggregates().get(0), is(ref));
+        assertThat(agg.groupings().size(), is(1));
+        assertThat(agg.groupings().get(0), is(ref));
+    }
+
     /**
      * Expects
      * Project[[x{r}#5]]