浏览代码

ESQL: Nested expressions inside stats command (#104387)

Allow nested expressions to be used both for grouping or inside
 aggregate functions inside the stats command.
As such the grammar has been tweaked to allow the stats group to have 
 optional aliasing.
As part of this fix, preserve the original field declaration (including 
 spaces) for implicit aliases.
Improve validation for incorrect aggregate function use (as arguments,
 grouping or inside evals).

Fix #99828
Costin Leau 1 年之前
父节点
当前提交
607185b280
共有 22 个文件被更改,包括 565 次插入290 次删除
  1. 6 0
      docs/changelog/104387.yaml
  2. 4 4
      docs/reference/esql/multivalued-fields.asciidoc
  3. 3 3
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec
  4. 2 2
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/null.csv-spec
  5. 57 0
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
  6. 2 2
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_count_distinct.csv-spec
  7. 2 2
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec
  8. 2 6
      x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4
  9. 29 24
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java
  10. 105 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java
  11. 0 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.interp
  12. 182 178
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.java
  13. 0 12
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseListener.java
  14. 0 7
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseVisitor.java
  15. 0 10
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
  16. 0 6
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
  17. 43 5
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
  18. 6 10
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
  19. 13 6
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
  20. 98 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
  21. 1 1
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
  22. 10 10
      x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/Expressions.java

+ 6 - 0
docs/changelog/104387.yaml

@@ -0,0 +1,6 @@
+pr: 104387
+summary: "ESQL: Nested expressions inside stats command"
+area: ES|QL
+type: enhancement
+issues:
+ - 99828

+ 4 - 4
docs/reference/esql/multivalued-fields.asciidoc

@@ -201,8 +201,8 @@ POST /_query
   "columns": [
     { "name": "a",   "type": "long"},
     { "name": "b",   "type": "long"},
-    { "name": "b+2", "type": "long"},
-    { "name": "a+b", "type": "long"}
+    { "name": "b + 2", "type": "long"},
+    { "name": "a + b", "type": "long"}
   ],
   "values": [
     [1, [1, 2], null, null],
@@ -236,8 +236,8 @@ POST /_query
   "columns": [
     { "name": "a",   "type": "long"},
     { "name": "b",   "type": "long"},
-    { "name": "b+2", "type": "long"},
-    { "name": "a+b", "type": "long"}
+    { "name": "b + 2", "type": "long"},
+    { "name": "a + b", "type": "long"}
   ],
   "values": [
     [1, 1, 3, 2],

+ 3 - 3
x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec

@@ -319,7 +319,7 @@ Parto          |Bamford        |6.004230000000001
 // end::evalReplace-result[]
 ;
 
-docsEvalUnnamedColumn
+docsEvalUnnamedColumn#[skip:-8.12.99,reason:expression spaces are maintained since 8.13]
 // tag::evalUnnamedColumn[]
 FROM employees
 | SORT emp_no
@@ -329,7 +329,7 @@ FROM employees
 | LIMIT 3;
 
 // tag::evalUnnamedColumn-result[]
-first_name:keyword | last_name:keyword | height:double | height*3.281:double
+first_name:keyword | last_name:keyword | height:double | height * 3.281:double
 Georgi         |Facello        |2.03           |6.66043           
 Bezalel        |Simmel         |2.08           |6.82448           
 Parto          |Bamford        |1.83           |6.004230000000001
@@ -348,4 +348,4 @@ FROM employees
 avg_height_feet:double
 5.801464200000001
 // end::evalUnnamedColumnStats-result[]
-;
+;

+ 2 - 2
x-pack/plugin/esql/qa/testFixtures/src/main/resources/null.csv-spec

@@ -56,7 +56,7 @@ COUNT(emp_no):long
 // end::is-not-null-result[]
 ;
 
-coalesceSimple
+coalesceSimple#[skip:-8.12.99,reason:expression spaces are maintained since 8.13]
 // tag::coalesce[]
 ROW a=null, b="b"
 | EVAL COALESCE(a, b)
@@ -64,7 +64,7 @@ ROW a=null, b="b"
 ;
 
 // tag::coalesce-result[]
-a:null | b:keyword | COALESCE(a,b):keyword
+a:null | b:keyword | COALESCE(a, b):keyword
   null |         b | b
 // end::coalesce-result[]
 ;

+ 57 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

@@ -879,3 +879,60 @@ AVG(salary):double | avg_salary_rounded:double
 48248.55           | 48249.0
 // end::statsUnnamedColumnEval-result[]
 ;
+
+nestedExpressionNoGrouping#[skip:-8.12.99,reason:StatsNestedExp breaks bwc]
+FROM employees
+| STATS s = SUM(emp_no + 3), c = COUNT(emp_no)
+;
+
+s: long | c: long
+1005350 | 100
+;
+
+nestedExpressionInSurrogateAgg#[skip:-8.12.99,reason:StatsNestedExp breaks bwc]
+FROM employees
+| STATS a = AVG(emp_no % 5), s = SUM(emp_no % 5), c = COUNT(emp_no % 5)
+;
+
+a:double | s:long | c:long
+2.0      | 200    | 100
+;
+
+nestedExpressionInGroupingWithAlias#[skip:-8.12.99,reason:StatsNestedExp breaks bwc]
+FROM employees
+| STATS s = SUM(emp_no % 5), c = COUNT(emp_no % 5) BY l = languages + 20
+| SORT l
+;
+
+s:long | c:long | l : i
+39     | 15     | 21  
+36     | 19     | 22  
+30     | 17     | 23  
+32     | 18     | 24  
+43     | 21     | 25  
+20     | 10     | null
+;
+
+nestedMultiExpressionInGroupingsAndAggs#[skip:-8.12.99,reason:StatsNestedExp breaks bwc]
+FROM employees 
+| EVAL sal = salary + 10000 
+| STATS sum(sal), sum(salary + 10000) BY left(first_name, 1), concat(gender,   to_string(languages))
+| SORT `left(first_name, 1)`, `concat(gender,   to_string(languages))`
+| LIMIT 5
+;
+
+sum(sal):l | sum(salary + 10000):l | left(first_name, 1):s  | concat(gender,   to_string(languages)):s
+54307      | 54307                  |  A                    | F2
+70335      | 70335                  |  A                    | F3
+76817      | 76817                  |  A                    | F5
+123675     | 123675                 |  A                    | M3
+43370      | 43370                  |  B                    | F2
+;
+
+
+
+
+
+
+
+

+ 2 - 2
x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_count_distinct.csv-spec

@@ -87,7 +87,7 @@ COUNT_DISTINCT(ip0):long | COUNT_DISTINCT(ip1):long
 // end::count-distinct-result[]
 ;
 
-countDistinctOfIpPrecision
+countDistinctOfIpPrecision#[skip:-8.12.99,reason:expression spaces are maintained since 8.13]
 // tag::count-distinct-precision[]
 FROM hosts
 | STATS COUNT_DISTINCT(ip0, 80000), COUNT_DISTINCT(ip1, 5)
@@ -95,7 +95,7 @@ FROM hosts
 ;
 
 // tag::count-distinct-precision-result[]
-COUNT_DISTINCT(ip0,80000):long | COUNT_DISTINCT(ip1,5):long
+COUNT_DISTINCT(ip0, 80000):long | COUNT_DISTINCT(ip1, 5):long
 7                              | 9
 // end::count-distinct-precision-result[]
 ;

+ 2 - 2
x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec

@@ -77,7 +77,7 @@ m:double   | p50:double
 0          | 0 
 ;
 
-medianOfInteger#[skip:-8.11.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765]
+medianOfInteger#[skip:-8.12.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765/Expression spaces are maintained since 8.13]
 // tag::median[]
 FROM employees
 | STATS MEDIAN(salary), PERCENTILE(salary, 50)
@@ -85,7 +85,7 @@ FROM employees
 ;
 
 // tag::median-result[]
-MEDIAN(salary):double | PERCENTILE(salary,50):double
+MEDIAN(salary):double | PERCENTILE(salary, 50):double
 47003                 | 47003    
 // end::median-result[]
 ;

+ 2 - 6
x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4

@@ -111,15 +111,11 @@ evalCommand
     ;
 
 statsCommand
-    : STATS fields? (BY grouping)?
+    : STATS stats=fields? (BY grouping=fields)?
     ;
 
 inlinestatsCommand
-    : INLINESTATS fields (BY grouping)?
-    ;
-
-grouping
-    : qualifiedName (COMMA qualifiedName)*
+    : INLINESTATS stats=fields (BY grouping=fields)?
     ;
 
 fromIdentifier

+ 29 - 24
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java

@@ -21,11 +21,8 @@ import org.elasticsearch.xpack.ql.capabilities.Unresolvable;
 import org.elasticsearch.xpack.ql.common.Failure;
 import org.elasticsearch.xpack.ql.expression.Alias;
 import org.elasticsearch.xpack.ql.expression.Expression;
-import org.elasticsearch.xpack.ql.expression.FieldAttribute;
-import org.elasticsearch.xpack.ql.expression.Literal;
-import org.elasticsearch.xpack.ql.expression.MetadataAttribute;
+import org.elasticsearch.xpack.ql.expression.Expressions;
 import org.elasticsearch.xpack.ql.expression.NamedExpression;
-import org.elasticsearch.xpack.ql.expression.ReferenceAttribute;
 import org.elasticsearch.xpack.ql.expression.TypeResolutions;
 import org.elasticsearch.xpack.ql.expression.UnresolvedAttribute;
 import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
@@ -150,36 +147,39 @@ public class Verifier {
 
     private static void checkAggregate(LogicalPlan p, Set<Failure> failures) {
         if (p instanceof Aggregate agg) {
+            // check aggregates
             agg.aggregates().forEach(e -> {
-                var exp = e instanceof Alias ? ((Alias) e).child() : e;
-                if (exp instanceof AggregateFunction aggFunc) {
-                    Expression field = aggFunc.field();
-
-                    // TODO: allow an expression?
-                    if ((field instanceof FieldAttribute
-                        || field instanceof MetadataAttribute
-                        || field instanceof ReferenceAttribute
-                        || field instanceof Literal) == false) {
+                var exp = e instanceof Alias a ? a.child() : e;
+                if (exp instanceof AggregateFunction af) {
+                    af.field().forEachDown(AggregateFunction.class, f -> {
+                        failures.add(fail(f, "nested aggregations [{}] not allowed inside other aggregations [{}]", f, af));
+                    });
+                } else {
+                    if (Expressions.match(agg.groupings(), g -> {
+                        Expression to = g instanceof Alias al ? al.child() : g;
+                        return to.semanticEquals(exp);
+                    }) == false) {
                         failures.add(
                             fail(
-                                e,
-                                "aggregate function's field must be an attribute or literal; found ["
-                                    + field.sourceText()
+                                exp,
+                                "expected an aggregate function or group but got ["
+                                    + exp.sourceText()
                                     + "] of type ["
-                                    + field.nodeName()
+                                    + exp.nodeName()
                                     + "]"
                             )
                         );
                     }
-                } else if (agg.groupings().contains(exp) == false) { // TODO: allow an expression?
-                    failures.add(
-                        fail(
-                            exp,
-                            "expected an aggregate function or group but got [" + exp.sourceText() + "] of type [" + exp.nodeName() + "]"
-                        )
-                    );
                 }
             });
+
+            // check grouping
+            // The grouping can not be an aggregate function
+            agg.groupings().forEach(e -> e.forEachUp(g -> {
+                if (g instanceof AggregateFunction af) {
+                    failures.add(fail(g, "cannot use an aggregate [{}] for grouping", af));
+                }
+            }));
         }
     }
 
@@ -214,12 +214,17 @@ public class Verifier {
     private static void checkEvalFields(LogicalPlan p, Set<Failure> failures) {
         if (p instanceof Eval eval) {
             eval.fields().forEach(field -> {
+                // check supported types
                 DataType dataType = field.dataType();
                 if (EsqlDataTypes.isRepresentable(dataType) == false) {
                     failures.add(
                         fail(field, "EVAL does not support type [{}] in expression [{}]", dataType.typeName(), field.child().sourceText())
                     );
                 }
+                // check no aggregate functions are used
+                field.forEachDown(AggregateFunction.class, af -> {
+                    failures.add(fail(af, "aggregate function [{}] not allowed outside STATS command", af.sourceText()));
+                });
             });
         }
     }

+ 105 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java

@@ -153,6 +153,7 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
             Limiter.ONCE,
             new SubstituteSurrogates(),
             new ReplaceRegexMatch(),
+            new ReplaceNestedExpressionWithEval(),
             new ReplaceAliasingEvalWithProject()
             // new NormalizeAggregate(), - waits on https://github.com/elastic/elasticsearch/issues/100634
         );
@@ -245,7 +246,7 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
             return plan;
         }
 
-        private static String temporaryName(NamedExpression agg, AggregateFunction af) {
+        static String temporaryName(NamedExpression agg, AggregateFunction af) {
             return "__" + agg.name() + "_" + af.functionName() + "@" + Integer.toHexString(af.hashCode());
         }
     }
@@ -1056,6 +1057,109 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
         }
     }
 
+    /**
+     * Replace nested expressions inside an aggregate with synthetic eval (which end up being projected away by the aggregate).
+     * stats sum(a + 1) by x % 2
+     * becomes
+     * eval `a + 1` = a + 1, `x % 2` = x % 2 | stats sum(`a+1`_ref) by `x % 2`_ref
+     */
+    static class ReplaceNestedExpressionWithEval extends OptimizerRules.OptimizerRule<Aggregate> {
+
+        @Override
+        protected LogicalPlan rule(Aggregate aggregate) {
+            List<Alias> evals = new ArrayList<>();
+            Map<String, Attribute> evalNames = new HashMap<>();
+            List<Expression> newGroupings = new ArrayList<>(aggregate.groupings());
+            boolean groupingChanged = false;
+
+            // start with the groupings since the aggs might duplicate it
+            for (int i = 0, s = newGroupings.size(); i < s; i++) {
+                Expression g = newGroupings.get(i);
+                // move the alias into an eval and replace it with its attribute
+                if (g instanceof Alias as) {
+                    groupingChanged = true;
+                    var attr = as.toAttribute();
+                    evals.add(as);
+                    evalNames.put(as.name(), attr);
+                    newGroupings.set(i, attr);
+                }
+            }
+
+            Holder<Boolean> aggsChanged = new Holder<>(false);
+            List<? extends NamedExpression> aggs = aggregate.aggregates();
+            List<NamedExpression> newAggs = new ArrayList<>(aggs.size());
+
+            // map to track common expressions
+            Map<Expression, Attribute> expToAttribute = new HashMap<>();
+            for (Alias a : evals) {
+                expToAttribute.put(a.child().canonical(), a.toAttribute());
+            }
+
+            // for the aggs make sure to unwrap the agg function and check the existing groupings
+            for (int i = 0, s = aggs.size(); i < s; i++) {
+                NamedExpression agg = aggs.get(i);
+
+                NamedExpression a = (NamedExpression) agg.transformDown(Alias.class, as -> {
+                    // if the child a nested expression
+                    Expression child = as.child();
+
+                    // shortcut for common scenario
+                    if (child instanceof AggregateFunction af && af.field() instanceof Attribute) {
+                        return as;
+                    }
+
+                    // check if the alias matches any from grouping otherwise unwrap it
+                    Attribute ref = evalNames.get(as.name());
+                    if (ref != null) {
+                        aggsChanged.set(true);
+                        return ref;
+                    }
+
+                    // TODO: break expression into aggregate functions (sum(x + 1) / max(y + 2))
+                    // List<Expression> afs = a.collectFirstChildren(AggregateFunction.class::isInstance);
+
+                    // 1. look for the aggregate function
+                    var replaced = child.transformUp(AggregateFunction.class, af -> {
+                        Expression result = af;
+
+                        Expression field = af.field();
+                        // 2. if the field is a nested expression (not attribute or literal), replace it
+                        if (field instanceof Attribute == false && field.foldable() == false) {
+                            // 3. create a new alias if one doesn't exist yet no reference
+                            Attribute attr = expToAttribute.computeIfAbsent(field.canonical(), k -> {
+                                Alias newAlias = new Alias(k.source(), temporaryName(agg, af), null, k, null, true);
+                                evals.add(newAlias);
+                                aggsChanged.set(true);
+                                return newAlias.toAttribute();
+                            });
+                            // replace field with attribute
+                            result = af.replaceChildren(Collections.singletonList(attr));
+                        }
+                        return result;
+                    });
+
+                    return as.replaceChild(replaced);
+                });
+
+                newAggs.add(a);
+            }
+
+            if (evals.size() > 0) {
+                var groupings = groupingChanged ? newGroupings : aggregate.groupings();
+                var aggregates = aggsChanged.get() ? newAggs : aggregate.aggregates();
+
+                var newEval = new Eval(aggregate.source(), aggregate.child(), evals);
+                aggregate = new Aggregate(aggregate.source(), newEval, groupings, aggregates);
+            }
+
+            return aggregate;
+        }
+
+        static String temporaryName(NamedExpression agg, AggregateFunction af) {
+            return SubstituteSurrogates.temporaryName(agg, af);
+        }
+    }
+
     /**
      * Replace aliasing evals (eval x=a) with a projection which can be further combined / simplified.
      * The rule gets applied only if there's another project (Project/Stats) above it.

文件差异内容过多而无法显示
+ 0 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.interp


文件差异内容过多而无法显示
+ 182 - 178
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.java


+ 0 - 12
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseListener.java

@@ -384,18 +384,6 @@ public class EsqlBaseParserBaseListener implements EsqlBaseParserListener {
    * <p>The default implementation does nothing.</p>
    */
   @Override public void exitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx) { }
-  /**
-   * {@inheritDoc}
-   *
-   * <p>The default implementation does nothing.</p>
-   */
-  @Override public void enterGrouping(EsqlBaseParser.GroupingContext ctx) { }
-  /**
-   * {@inheritDoc}
-   *
-   * <p>The default implementation does nothing.</p>
-   */
-  @Override public void exitGrouping(EsqlBaseParser.GroupingContext ctx) { }
   /**
    * {@inheritDoc}
    *

+ 0 - 7
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseVisitor.java

@@ -229,13 +229,6 @@ public class EsqlBaseParserBaseVisitor<T> extends AbstractParseTreeVisitor<T> im
    * {@link #visitChildren} on {@code ctx}.</p>
    */
   @Override public T visitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx) { return visitChildren(ctx); }
-  /**
-   * {@inheritDoc}
-   *
-   * <p>The default implementation returns the result of calling
-   * {@link #visitChildren} on {@code ctx}.</p>
-   */
-  @Override public T visitGrouping(EsqlBaseParser.GroupingContext ctx) { return visitChildren(ctx); }
   /**
    * {@inheritDoc}
    *

+ 0 - 10
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java

@@ -351,16 +351,6 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
    * @param ctx the parse tree
    */
   void exitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx);
-  /**
-   * Enter a parse tree produced by {@link EsqlBaseParser#grouping}.
-   * @param ctx the parse tree
-   */
-  void enterGrouping(EsqlBaseParser.GroupingContext ctx);
-  /**
-   * Exit a parse tree produced by {@link EsqlBaseParser#grouping}.
-   * @param ctx the parse tree
-   */
-  void exitGrouping(EsqlBaseParser.GroupingContext ctx);
   /**
    * Enter a parse tree produced by {@link EsqlBaseParser#fromIdentifier}.
    * @param ctx the parse tree

+ 0 - 6
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java

@@ -213,12 +213,6 @@ public interface EsqlBaseParserVisitor<T> extends ParseTreeVisitor<T> {
    * @return the visitor result
    */
   T visitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx);
-  /**
-   * Visit a parse tree produced by {@link EsqlBaseParser#grouping}.
-   * @param ctx the parse tree
-   * @return the visitor result
-   */
-  T visitGrouping(EsqlBaseParser.GroupingContext ctx);
   /**
    * Visit a parse tree produced by {@link EsqlBaseParser#fromIdentifier}.
    * @param ctx the parse tree

+ 43 - 5
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java

@@ -33,6 +33,7 @@ import org.elasticsearch.xpack.esql.type.EsqlDataTypes;
 import org.elasticsearch.xpack.ql.InvalidArgumentException;
 import org.elasticsearch.xpack.ql.QlIllegalArgumentException;
 import org.elasticsearch.xpack.ql.expression.Alias;
+import org.elasticsearch.xpack.ql.expression.Attribute;
 import org.elasticsearch.xpack.ql.expression.Expression;
 import org.elasticsearch.xpack.ql.expression.Literal;
 import org.elasticsearch.xpack.ql.expression.NamedExpression;
@@ -58,12 +59,12 @@ import java.math.BigInteger;
 import java.time.Duration;
 import java.time.ZoneId;
 import java.time.temporal.TemporalAmount;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.function.BiFunction;
 
-import static java.util.Collections.emptyList;
 import static java.util.Collections.singletonList;
 import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.parseTemporalAmout;
 import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.DATE_PERIOD;
@@ -409,13 +410,50 @@ public abstract class ExpressionBuilder extends IdentifierBuilder {
     public Alias visitField(EsqlBaseParser.FieldContext ctx) {
         UnresolvedAttribute id = visitQualifiedName(ctx.qualifiedName());
         Expression value = expression(ctx.booleanExpression());
-        String name = id == null ? ctx.getText() : id.qualifiedName();
-        return new Alias(source(ctx), name, value);
+        var source = source(ctx);
+        String name = id == null ? source.text() : id.qualifiedName();
+        return new Alias(source, name, value);
     }
 
     @Override
-    public List<NamedExpression> visitGrouping(EsqlBaseParser.GroupingContext ctx) {
-        return ctx != null ? visitList(this, ctx.qualifiedName(), NamedExpression.class) : emptyList();
+    public List<Alias> visitFields(EsqlBaseParser.FieldsContext ctx) {
+        return ctx != null ? visitList(this, ctx.field(), Alias.class) : new ArrayList<>();
+    }
+
+    /**
+     * Similar to {@link #visitFields(EsqlBaseParser.FieldsContext)} however avoids wrapping the exception
+     * into an Alias.
+     */
+    public List<NamedExpression> visitGrouping(EsqlBaseParser.FieldsContext ctx) {
+        List<NamedExpression> list;
+        if (ctx != null) {
+            var fields = ctx.field();
+            list = new ArrayList<>(fields.size());
+            for (EsqlBaseParser.FieldContext field : fields) {
+                NamedExpression ne = null;
+                UnresolvedAttribute id = visitQualifiedName(field.qualifiedName());
+                Expression value = expression(field.booleanExpression());
+                String name = null;
+                if (id == null) {
+                    // when no alias has been specified, see if the underling one can be reused
+                    if (value instanceof Attribute a) {
+                        ne = a;
+                    } else {
+                        name = source(field).text();
+                    }
+                } else {
+                    name = id.qualifiedName();
+                }
+                // wrap when necessary - no alias and no underlying attribute
+                if (ne == null) {
+                    ne = new Alias(source(ctx), name, value);
+                }
+                list.add(ne);
+            }
+        } else {
+            list = new ArrayList<>();
+        }
+        return list;
     }
 
     @Override

+ 6 - 10
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java

@@ -53,6 +53,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -197,14 +198,14 @@ public class LogicalPlanBuilder extends ExpressionBuilder {
 
     @Override
     public PlanFactory visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx) {
-        List<NamedExpression> aggregates = new ArrayList<>(visitFields(ctx.fields()));
-        List<NamedExpression> groupings = visitGrouping(ctx.grouping());
+        List<NamedExpression> aggregates = new ArrayList<>(visitFields(ctx.stats));
+        List<NamedExpression> groupings = visitGrouping(ctx.grouping);
         if (aggregates.isEmpty() && groupings.isEmpty()) {
             throw new ParsingException(source(ctx), "At least one aggregation or grouping expression required in [{}]", ctx.getText());
         }
         // grouping keys are automatically added as aggregations however the user is not allowed to specify them
         if (groupings.isEmpty() == false && aggregates.isEmpty() == false) {
-            var groupNames = Expressions.names(groupings);
+            var groupNames = new LinkedHashSet<>(Expressions.names(Expressions.references(groupings)));
 
             for (NamedExpression aggregate : aggregates) {
                 if (aggregate instanceof Alias a && a.child() instanceof UnresolvedAttribute ua && groupNames.contains(ua.name())) {
@@ -218,8 +219,8 @@ public class LogicalPlanBuilder extends ExpressionBuilder {
 
     @Override
     public PlanFactory visitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandContext ctx) {
-        List<NamedExpression> aggregates = new ArrayList<>(visitFields(ctx.fields()));
-        List<NamedExpression> groupings = visitGrouping(ctx.grouping());
+        List<NamedExpression> aggregates = new ArrayList<>(visitFields(ctx.stats));
+        List<NamedExpression> groupings = visitGrouping(ctx.grouping);
         aggregates.addAll(groupings);
         return input -> new InlineStats(source(ctx), input, new ArrayList<>(groupings), aggregates);
     }
@@ -230,11 +231,6 @@ public class LogicalPlanBuilder extends ExpressionBuilder {
         return input -> new Filter(source(ctx), input, expression);
     }
 
-    @Override
-    public List<Alias> visitFields(EsqlBaseParser.FieldsContext ctx) {
-        return ctx != null ? visitList(this, ctx.field(), Alias.class) : new ArrayList<>();
-    }
-
     @Override
     public PlanFactory visitLimitCommand(EsqlBaseParser.LimitCommandContext ctx) {
         Source source = source(ctx);

+ 13 - 6
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

@@ -70,17 +70,13 @@ public class VerifierTests extends ESTestCase {
             error("from test | stats length(first_name), count(1) by first_name")
         );
         assertEquals(
-            "1:19: aggregate function's field must be an attribute or literal; found [emp_no / 2] of type [Div]",
-            error("from test | stats x = avg(emp_no / 2) by emp_no")
+            "1:23: nested aggregations [max(salary)] not allowed inside other aggregations [max(max(salary))]",
+            error("from test | stats max(max(salary)) by first_name")
         );
         assertEquals(
             "1:25: argument of [avg(first_name)] must be [numeric], found value [first_name] type [keyword]",
             error("from test | stats count(avg(first_name)) by first_name")
         );
-        assertEquals(
-            "1:19: aggregate function's field must be an attribute or literal; found [length(first_name)] of type [Length]",
-            error("from test | stats count(length(first_name)) by first_name")
-        );
         assertEquals(
             "1:23: expected an aggregate function or group but got [emp_no + avg(emp_no)] of type [Add]",
             error("from test | stats x = emp_no + avg(emp_no) by emp_no")
@@ -95,6 +91,17 @@ public class VerifierTests extends ESTestCase {
         );
     }
 
+    public void testAggsInsideGrouping() {
+        assertEquals(
+            "1:36: cannot use an aggregate [max(languages)] for grouping",
+            error("from test| stats max(languages) by max(languages)")
+        );
+    }
+
+    public void testAggsInsideEval() throws Exception {
+        assertEquals("1:29: aggregate function [max(b)] not allowed outside STATS command", error("row a = 1, b = 2 | eval x = max(b)"));
+    }
+
     public void testDoubleRenamingField() {
         assertEquals(
             "1:44: Column [emp_no] renamed to [r1] and is no longer available [emp_no as r3]",

+ 98 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

@@ -35,6 +35,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow;
 import org.elasticsearch.xpack.esql.expression.function.scalar.math.Round;
 import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
+import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mod;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
 import org.elasticsearch.xpack.esql.parser.EsqlParser;
@@ -95,6 +96,7 @@ import static org.elasticsearch.xpack.ql.TestUtils.relation;
 import static org.elasticsearch.xpack.ql.tree.Source.EMPTY;
 import static org.elasticsearch.xpack.ql.type.DataTypes.INTEGER;
 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.emptyArray;
@@ -2760,6 +2762,102 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
         var from = as(eval.child(), EsRelation.class);
     }
 
+    /**
+     * Expects
+     * Limit[500[INTEGER]]
+     * \_Aggregate[[emp_no%2{r}#6],[COUNT(salary{f}#12) AS c, emp_no%2{r}#6]]
+     *   \_Eval[[emp_no{f}#7 % 2[INTEGER] AS emp_no%2]]
+     *     \_EsRelation[test][_meta_field{f}#13, emp_no{f}#7, first_name{f}#8, ge..]
+     */
+    public void testNestedExpressionsInGroups() {
+        var plan = optimizedPlan("""
+            from test
+            | stats c = count(salary) by emp_no % 2
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        var groupings = agg.groupings();
+        var aggs = agg.aggregates();
+        var ref = as(groupings.get(0), ReferenceAttribute.class);
+        assertThat(aggs.get(1), is(ref));
+        var eval = as(agg.child(), Eval.class);
+        assertThat(eval.fields(), hasSize(1));
+        assertThat(eval.fields().get(0).toAttribute(), is(ref));
+        assertThat(eval.fields().get(0).name(), is("emp_no % 2"));
+    }
+
+    /**
+     * Expects
+     * Limit[500[INTEGER]]
+     * \_Aggregate[[emp_no{f}#6],[COUNT(__c_COUNT@1bd45f36{r}#16) AS c, emp_no{f}#6]]
+     *   \_Eval[[salary{f}#11 + 1[INTEGER] AS __c_COUNT@1bd45f36]]
+     *     \_EsRelation[test][_meta_field{f}#12, emp_no{f}#6, first_name{f}#7, ge..]
+     */
+    public void testNestedExpressionsInAggs() {
+        var plan = optimizedPlan("""
+            from test
+            | stats c = count(salary + 1) by emp_no
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        var aggs = agg.aggregates();
+        var count = aliased(aggs.get(0), Count.class);
+        var ref = as(count.field(), ReferenceAttribute.class);
+        var eval = as(agg.child(), Eval.class);
+        var fields = eval.fields();
+        assertThat(fields, hasSize(1));
+        assertThat(fields.get(0).toAttribute(), is(ref));
+        var add = aliased(fields.get(0), Add.class);
+        assertThat(Expressions.name(add.left()), is("salary"));
+    }
+
+    /**
+     * Limit[500[INTEGER]]
+     * \_Aggregate[[emp_no%2{r}#7],[COUNT(__c_COUNT@fb7855b0{r}#18) AS c, emp_no%2{r}#7]]
+     *   \_Eval[[emp_no{f}#8 % 2[INTEGER] AS emp_no%2, 100[INTEGER] / languages{f}#11 + salary{f}#13 + 1[INTEGER] AS __c_COUNT
+     * @fb7855b0]]
+     *     \_EsRelation[test][_meta_field{f}#14, emp_no{f}#8, first_name{f}#9, ge..]
+     */
+    public void testNestedExpressionsInBothAggsAndGroups() {
+        var plan = optimizedPlan("""
+            from test
+            | stats c = count(salary + 1 + 100 / languages) by emp_no % 2
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        var groupings = agg.groupings();
+        var aggs = agg.aggregates();
+        var gRef = as(groupings.get(0), ReferenceAttribute.class);
+        assertThat(aggs.get(1), is(gRef));
+
+        var count = aliased(aggs.get(0), Count.class);
+        var aggRef = as(count.field(), ReferenceAttribute.class);
+        var eval = as(agg.child(), Eval.class);
+        var fields = eval.fields();
+        assertThat(fields, hasSize(2));
+        assertThat(fields.get(0).toAttribute(), is(gRef));
+        assertThat(fields.get(1).toAttribute(), is(aggRef));
+
+        var mod = aliased(fields.get(0), Mod.class);
+        assertThat(Expressions.name(mod.left()), is("emp_no"));
+        var refs = Expressions.references(singletonList(fields.get(1)));
+        assertThat(Expressions.names(refs), containsInAnyOrder("languages", "salary"));
+    }
+
+    public void testNestedMultiExpressionsInGroupingAndAggs() {
+        var plan = optimizedPlan("""
+            from test
+            | stats count(salary + 1), max(salary   +  23) by languages   + 1, emp_no %  3
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        assertThat(Expressions.names(agg.output()), contains("count(salary + 1)", "max(salary   +  23)", "languages   + 1", "emp_no %  3"));
+    }
+
     private LogicalPlan optimizedPlan(String query) {
         return plan(query);
     }

+ 1 - 1
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java

@@ -227,7 +227,7 @@ public class StatementParserTests extends ESTestCase {
                 List.of(
                     new Alias(
                         EMPTY,
-                        "fn(a+1)",
+                        "fn(a + 1)",
                         new UnresolvedFunction(EMPTY, "fn", DEFAULT, List.of(new Add(EMPTY, attribute("a"), integer(1))))
                     )
                 )

+ 10 - 10
x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/Expressions.java

@@ -28,7 +28,7 @@ public final class Expressions {
     private Expressions() {}
 
     public static NamedExpression wrapAsNamed(Expression exp) {
-        return exp instanceof NamedExpression ? (NamedExpression) exp : new Alias(exp.source(), exp.sourceText(), exp);
+        return exp instanceof NamedExpression ne ? ne : new Alias(exp.source(), exp.sourceText(), exp);
     }
 
     public static List<Attribute> asAttributes(List<? extends NamedExpression> named) {
@@ -136,7 +136,7 @@ public final class Expressions {
     }
 
     public static String name(Expression e) {
-        return e instanceof NamedExpression ? ((NamedExpression) e).name() : e.sourceText();
+        return e instanceof NamedExpression ne ? ne.name() : e.sourceText();
     }
 
     public static boolean isNull(Expression e) {
@@ -153,8 +153,8 @@ public final class Expressions {
     }
 
     public static Attribute attribute(Expression e) {
-        if (e instanceof NamedExpression) {
-            return ((NamedExpression) e).toAttribute();
+        if (e instanceof NamedExpression ne) {
+            return ne.toAttribute();
         }
         return null;
     }
@@ -175,8 +175,8 @@ public final class Expressions {
         // an alias of same name and data type can be reused (by mistake): need to use a list to collect all refs (and later report them)
         List<Tuple<Attribute, Expression>> aliases = new ArrayList<>();
         for (NamedExpression ne : named) {
-            if (ne instanceof Alias) {
-                aliases.add(new Tuple<>(ne.toAttribute(), ((Alias) ne).child()));
+            if (ne instanceof Alias as) {
+                aliases.add(new Tuple<>(ne.toAttribute(), as.child()));
             }
         }
         return aliases;
@@ -218,11 +218,11 @@ public final class Expressions {
         if (e.foldable()) {
             return new ConstantInput(e.source(), e, e.fold());
         }
-        if (e instanceof NamedExpression) {
-            return new AttributeInput(e.source(), e, ((NamedExpression) e).toAttribute());
+        if (e instanceof NamedExpression ne) {
+            return new AttributeInput(e.source(), e, ne.toAttribute());
         }
-        if (e instanceof Function) {
-            return ((Function) e).asPipe();
+        if (e instanceof Function f) {
+            return f.asPipe();
         }
         throw new QlIllegalArgumentException("Cannot create pipe for {}", e);
     }

部分文件因为文件数量过多而无法显示