Browse Source

ESQL: Make Categorize usable in aggs when identical to a grouping (#117835) (#117895)

Cases like `STATS MV_APPEND(cat, CATEGORIZE(x)) BY cat=CATEGORIZE(x)` should work, as they're moved to an EVAL by a rule.

Also, these cases were discarded, as they fail because of other verifications (Which also fail for BUCKET):
```
STATS x = category BY category=CATEGORIZE(message)
STATS x = CATEGORIZE(message) BY CATEGORIZE(message)
STATS x = CATEGORIZE(message) BY category=CATEGORIZE(message)
Iván Cea Fontenla 10 months ago
parent
commit
d62be0f1ce

+ 21 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec

@@ -503,6 +503,27 @@ FROM employees
 //end::reuseGroupingFunctionWithExpression-result[]
 ;
 
+reuseGroupingFunctionImplicitAliasWithExpression#[skip:-8.13.99, reason:BUCKET renamed in 8.14]
+FROM employees
+| STATS s1 = `BUCKET(salary / 100 + 99, 50.)` + 1, s2 = BUCKET(salary / 1000 + 999, 50.) + 2 BY BUCKET(salary / 100 + 99, 50.), b2 = BUCKET(salary / 1000 + 999, 50.)
+| SORT `BUCKET(salary / 100 + 99, 50.)`, b2
+| KEEP s1, `BUCKET(salary / 100 + 99, 50.)`, s2, b2
+;
+
+ s1:double | BUCKET(salary / 100 + 99, 50.):double | s2:double   | b2:double
+351.0      |350.0      |1002.0       |1000.0
+401.0      |400.0      |1002.0       |1000.0
+451.0      |450.0      |1002.0       |1000.0
+501.0      |500.0      |1002.0       |1000.0
+551.0      |550.0      |1002.0       |1000.0
+601.0      |600.0      |1002.0       |1000.0
+601.0      |600.0      |1052.0       |1050.0
+651.0      |650.0      |1052.0       |1050.0
+701.0      |700.0      |1052.0       |1050.0
+751.0      |750.0      |1052.0       |1050.0
+801.0      |800.0      |1052.0       |1050.0
+;
+
 reuseGroupingFunctionWithinAggs#[skip:-8.13.99, reason:BUCKET renamed in 8.14]
 FROM employees
 | STATS sum = 1 + MAX(1 + BUCKET(salary, 1000.)) BY BUCKET(salary, 1000.) + 1

+ 83 - 38
x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec

@@ -1,5 +1,5 @@
 standard aggs
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS count=COUNT(),
@@ -17,7 +17,7 @@ count:long | sum:long |     avg:double     | count_distinct:long | category:keyw
 ;
 
 values aggs
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS values=MV_SORT(VALUES(message)),
@@ -33,7 +33,7 @@ values:keyword                                                        |      top
 ;
 
 mv
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM mv_sample_data
   | STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(message)
@@ -48,7 +48,7 @@ COUNT():long | SUM(event_duration):long | category:keyword
 ;
 
 row mv
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = ["connected to a", "connected to b", "disconnected"], str = ["a", "b", "c"]
   | STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message)
@@ -61,7 +61,7 @@ COUNT():long | VALUES(str):keyword | category:keyword
 ;
 
 skips stopwords
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = ["Mon Tue connected to a", "Jul Aug connected to b September ", "UTC connected GMT to c UTC"]
   | STATS COUNT() BY category=CATEGORIZE(message)
@@ -73,7 +73,7 @@ COUNT():long | category:keyword
 ;
 
 with multiple indices
-required_capability: categorize_v4
+required_capability: categorize_v5
 required_capability: union_types
 
 FROM sample_data*
@@ -88,7 +88,7 @@ COUNT():long | category:keyword
 ;
 
 mv with many values
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM employees
   | STATS COUNT() BY category=CATEGORIZE(job_positions)
@@ -105,7 +105,7 @@ COUNT():long | category:keyword
 ;
 
 mv with many values and SUM
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM employees
   | STATS SUM(languages) BY category=CATEGORIZE(job_positions)
@@ -120,7 +120,7 @@ SUM(languages):long | category:keyword
 ;
 
 mv with many values and nulls and SUM
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM employees
   | STATS SUM(languages) BY category=CATEGORIZE(job_positions)
@@ -134,7 +134,7 @@ SUM(languages):long | category:keyword
 ;
 
 mv via eval
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | EVAL message = MV_APPEND(message, "Banana")
@@ -150,7 +150,7 @@ COUNT():long | category:keyword
 ;
 
 mv via eval const
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | EVAL message = ["Banana", "Bread"]
@@ -164,7 +164,7 @@ COUNT():long | category:keyword
 ;
 
 mv via eval const without aliases
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | EVAL message = ["Banana", "Bread"]
@@ -178,7 +178,7 @@ COUNT():long | CATEGORIZE(message):keyword
 ;
 
 mv const in parameter
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
@@ -191,7 +191,7 @@ COUNT():long | c:keyword
 ;
 
 agg alias shadowing
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS c = COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
@@ -206,7 +206,7 @@ c:keyword
 ;
 
 chained aggregations using categorize
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(message)
@@ -221,7 +221,7 @@ COUNT():long | category:keyword
 ;
 
 stats without aggs
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS BY category=CATEGORIZE(message)
@@ -235,7 +235,7 @@ category:keyword
 ;
 
 text field
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM hosts
   | STATS COUNT() BY category=CATEGORIZE(host_group)
@@ -253,7 +253,7 @@ COUNT():long | category:keyword
 ;
 
 on TO_UPPER
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(TO_UPPER(message))
@@ -267,7 +267,7 @@ COUNT():long | category:keyword
 ;
 
 on CONCAT
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " banana"))
@@ -281,7 +281,7 @@ COUNT():long | category:keyword
 ;
 
 on CONCAT with unicode
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " 👍🏽😊"))
@@ -295,7 +295,7 @@ COUNT():long | category:keyword
 ;
 
 on REVERSE(CONCAT())
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(REVERSE(CONCAT(message, " 👍🏽😊")))
@@ -309,7 +309,7 @@ COUNT():long | category:keyword
 ;
 
 and then TO_LOWER
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(message)
@@ -324,7 +324,7 @@ COUNT():long | category:keyword
 ;
 
 on const empty string
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE("")
@@ -336,7 +336,7 @@ COUNT():long | category:keyword
 ;
 
 on const empty string from eval
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | EVAL x = ""
@@ -349,7 +349,7 @@ COUNT():long | category:keyword
 ;
 
 on null
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | EVAL x = null
@@ -362,7 +362,7 @@ COUNT():long | SUM(event_duration):long | category:keyword
 ;
 
 on null string
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | EVAL x = null::string
@@ -375,7 +375,7 @@ COUNT():long | category:keyword
 ;
 
 filtering out all data
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | WHERE @timestamp < "2023-10-23T00:00:00Z"
@@ -387,7 +387,7 @@ COUNT():long | category:keyword
 ;
 
 filtering out all data with constant
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(message)
@@ -398,7 +398,7 @@ COUNT():long | category:keyword
 ;
 
 drop output columns
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS count=COUNT() BY category=CATEGORIZE(message)
@@ -413,7 +413,7 @@ x:integer
 ;
 
 category value processing
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = ["connected to a", "connected to b", "disconnected"]
   | STATS COUNT() BY category=CATEGORIZE(message)
@@ -427,7 +427,7 @@ COUNT():long | category:keyword
 ;
 
 row aliases
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = "connected to xyz"
   | EVAL x = message
@@ -441,7 +441,7 @@ COUNT():long | category:keyword           | y:keyword
 ;
 
 from aliases
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | EVAL x = message
@@ -457,7 +457,7 @@ COUNT():long | category:keyword         | y:keyword
 ;
 
 row aliases with keep
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = "connected to xyz"
   | EVAL x = message
@@ -473,7 +473,7 @@ COUNT():long | y:keyword
 ;
 
 from aliases with keep
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | EVAL x = message
@@ -491,7 +491,7 @@ COUNT():long | y:keyword
 ;
 
 row rename
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = "connected to xyz"
   | RENAME message as x
@@ -505,7 +505,7 @@ COUNT():long | y:keyword
 ;
 
 from rename
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | RENAME message as x
@@ -521,7 +521,7 @@ COUNT():long | y:keyword
 ;
 
 row drop
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 ROW message = "connected to a"
   | STATS c = COUNT() BY category=CATEGORIZE(message)
@@ -534,7 +534,7 @@ c:long
 ;
 
 from drop
-required_capability: categorize_v4
+required_capability: categorize_v5
 
 FROM sample_data
   | STATS c = COUNT() BY category=CATEGORIZE(message)
@@ -547,3 +547,48 @@ c:long
 3
 3
 ;
+
+categorize in aggs inside function
+required_capability: categorize_v5
+
+FROM sample_data
+  | STATS COUNT(), x = MV_APPEND(category, category) BY category=CATEGORIZE(message)
+  | SORT x
+  | KEEP `COUNT()`, x
+;
+
+COUNT():long | x:keyword
+           3 | [.*?Connected.+?to.*?,.*?Connected.+?to.*?]
+           3 | [.*?Connection.+?error.*?,.*?Connection.+?error.*?]
+           1 | [.*?Disconnected.*?,.*?Disconnected.*?]
+;
+
+categorize in aggs same as grouping inside function
+required_capability: categorize_v5
+
+FROM sample_data
+  | STATS COUNT(), x = MV_APPEND(CATEGORIZE(message), `CATEGORIZE(message)`) BY CATEGORIZE(message)
+  | SORT x
+  | KEEP `COUNT()`, x
+;
+
+COUNT():long | x:keyword
+           3 | [.*?Connected.+?to.*?,.*?Connected.+?to.*?]
+           3 | [.*?Connection.+?error.*?,.*?Connection.+?error.*?]
+           1 | [.*?Disconnected.*?,.*?Disconnected.*?]
+;
+
+categorize in aggs same as grouping inside function with explicit alias
+required_capability: categorize_v5
+
+FROM sample_data
+  | STATS COUNT(), x = MV_APPEND(CATEGORIZE(message), category) BY category=CATEGORIZE(message)
+  | SORT x
+  | KEEP `COUNT()`, x
+;
+
+COUNT():long | x:keyword
+           3 | [.*?Connected.+?to.*?,.*?Connected.+?to.*?]
+           3 | [.*?Connection.+?error.*?,.*?Connection.+?error.*?]
+           1 | [.*?Disconnected.*?,.*?Disconnected.*?]
+;

+ 1 - 1
x-pack/plugin/esql/qa/testFixtures/src/main/resources/docs.csv-spec

@@ -678,7 +678,7 @@ Bangalore     | 9                 | 72
 ;
 
 docsCategorize
-required_capability: categorize_v4
+required_capability: categorize_v5
 // tag::docsCategorize[]
 FROM sample_data
 | STATS count=COUNT() BY category=CATEGORIZE(message)

+ 1 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

@@ -396,7 +396,7 @@ public class EsqlCapabilities {
         /**
          * Supported the text categorization function "CATEGORIZE".
          */
-        CATEGORIZE_V4,
+        CATEGORIZE_V5,
 
         /**
          * QSTR function

+ 18 - 21
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java

@@ -20,7 +20,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.Expressions;
 import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
 import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
-import org.elasticsearch.xpack.esql.core.expression.NameId;
 import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
 import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
 import org.elasticsearch.xpack.esql.core.expression.function.Function;
@@ -63,12 +62,10 @@ import org.elasticsearch.xpack.esql.stats.Metrics;
 import java.util.ArrayList;
 import java.util.BitSet;
 import java.util.Collection;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Locale;
-import java.util.Map;
 import java.util.Set;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
@@ -364,35 +361,35 @@ public class Verifier {
                 );
         });
 
-        // Forbid CATEGORIZE being used in the aggregations
-        agg.aggregates().forEach(a -> {
-            a.forEachDown(
-                Categorize.class,
-                categorize -> failures.add(
-                    fail(categorize, "cannot use CATEGORIZE grouping function [{}] within the aggregations", categorize.sourceText())
+        // Forbid CATEGORIZE being used in the aggregations, unless it appears as a grouping
+        agg.aggregates()
+            .forEach(
+                a -> a.forEachDown(
+                    AggregateFunction.class,
+                    aggregateFunction -> aggregateFunction.forEachDown(
+                        Categorize.class,
+                        categorize -> failures.add(
+                            fail(categorize, "cannot use CATEGORIZE grouping function [{}] within an aggregation", categorize.sourceText())
+                        )
+                    )
                 )
             );
-        });
 
-        // Forbid CATEGORIZE being referenced in the aggregation functions
-        Map<NameId, Categorize> categorizeByAliasId = new HashMap<>();
+        // Forbid CATEGORIZE being referenced as a child of an aggregation function
+        AttributeMap<Categorize> categorizeByAttribute = new AttributeMap<>();
         agg.groupings().forEach(g -> {
             g.forEachDown(Alias.class, alias -> {
                 if (alias.child() instanceof Categorize categorize) {
-                    categorizeByAliasId.put(alias.id(), categorize);
+                    categorizeByAttribute.put(alias.toAttribute(), categorize);
                 }
             });
         });
         agg.aggregates()
             .forEach(a -> a.forEachDown(AggregateFunction.class, aggregate -> aggregate.forEachDown(Attribute.class, attribute -> {
-                var categorize = categorizeByAliasId.get(attribute.id());
+                var categorize = categorizeByAttribute.get(attribute);
                 if (categorize != null) {
                     failures.add(
-                        fail(
-                            attribute,
-                            "cannot reference CATEGORIZE grouping function [{}] within the aggregations",
-                            attribute.sourceText()
-                        )
+                        fail(attribute, "cannot reference CATEGORIZE grouping function [{}] within an aggregation", attribute.sourceText())
                     );
                 }
             })));
@@ -449,7 +446,7 @@ public class Verifier {
                 // check the bucketing function against the group
                 else if (c instanceof GroupingFunction gf) {
                     if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
-                        failures.add(fail(gf, "can only use grouping function [{}] part of the BY clause", gf.sourceText()));
+                        failures.add(fail(gf, "can only use grouping function [{}] as part of the BY clause", gf.sourceText()));
                     }
                 }
             });
@@ -466,7 +463,7 @@ public class Verifier {
             // optimizer will later unroll expressions with aggs and non-aggs with a grouping function into an EVAL, but that will no longer
             // be verified (by check above in checkAggregate()), so do it explicitly here
             if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
-                failures.add(fail(gf, "can only use grouping function [{}] part of the BY clause", gf.sourceText()));
+                failures.add(fail(gf, "can only use grouping function [{}] as part of the BY clause", gf.sourceText()));
             } else if (level == 0) {
                 addFailureOnGroupingUsedNakedInAggs(failures, gf, "function");
             }

+ 16 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateAggExpressionWithEval.java

@@ -9,18 +9,21 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical;
 
 import org.elasticsearch.common.util.Maps;
 import org.elasticsearch.xpack.esql.core.expression.Alias;
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
 import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.util.Holder;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
 import org.elasticsearch.xpack.esql.plan.logical.Eval;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.esql.plan.logical.Project;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
@@ -51,6 +54,16 @@ public final class ReplaceAggregateAggExpressionWithEval extends OptimizerRules.
         AttributeMap<Expression> aliases = new AttributeMap<>();
         aggregate.forEachExpressionUp(Alias.class, a -> aliases.put(a.toAttribute(), a.child()));
 
+        // Build Categorize grouping functions map.
+        // Functions like BUCKET() shouldn't reach this point,
+        // as they are moved to an early EVAL by ReplaceAggregateNestedExpressionWithEval
+        Map<Categorize, Attribute> groupingAttributes = new HashMap<>();
+        aggregate.forEachExpressionUp(Alias.class, a -> {
+            if (a.child() instanceof Categorize groupingFunction) {
+                groupingAttributes.put(groupingFunction, a.toAttribute());
+            }
+        });
+
         // break down each aggregate into AggregateFunction and/or grouping key
         // preserve the projection at the end
         List<? extends NamedExpression> aggs = aggregate.aggregates();
@@ -109,6 +122,9 @@ public final class ReplaceAggregateAggExpressionWithEval extends OptimizerRules.
                         return alias.toAttribute();
                     });
 
+                    // replace grouping functions with their references
+                    aggExpression = aggExpression.transformUp(Categorize.class, groupingAttributes::get);
+
                     Alias alias = as.replaceChild(aggExpression);
                     newEvals.add(alias);
                     newProjections.add(alias.toAttribute());

+ 5 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java

@@ -51,6 +51,7 @@ public final class ReplaceAggregateNestedExpressionWithEval extends OptimizerRul
             // Exception: Categorize is internal to the aggregation and remains in the groupings. We move its child expression into an eval.
             if (g instanceof Alias as) {
                 if (as.child() instanceof Categorize cat) {
+                    // For Categorize grouping function, we only move the child expression into an eval
                     if (cat.field() instanceof Attribute == false) {
                         groupingChanged = true;
                         var fieldAs = new Alias(as.source(), as.name(), cat.field(), null, true);
@@ -59,7 +60,6 @@ public final class ReplaceAggregateNestedExpressionWithEval extends OptimizerRul
                         evalNames.put(fieldAs.name(), fieldAttr);
                         Categorize replacement = cat.replaceChildren(List.of(fieldAttr));
                         newGroupings.set(i, as.replaceChild(replacement));
-                        groupingAttributes.put(cat, fieldAttr);
                     }
                 } else {
                     groupingChanged = true;
@@ -135,6 +135,10 @@ public final class ReplaceAggregateNestedExpressionWithEval extends OptimizerRul
                 });
                 // replace any grouping functions with their references pointing to the added synthetic eval
                 replaced = replaced.transformDown(GroupingFunction.class, gf -> {
+                    // Categorize in aggs depends on the grouping result, not on an early eval
+                    if (gf instanceof Categorize) {
+                        return gf;
+                    }
                     aggsChanged.set(true);
                     // should never return null, as it's verified.
                     // but even if broken, the transform will fail safely; otoh, returning `gf` will fail later due to incorrect plan.

+ 20 - 14
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

@@ -407,12 +407,12 @@ public class VerifierTests extends ESTestCase {
 
         // but fails if it's different
         assertEquals(
-            "1:32: can only use grouping function [bucket(a, 3)] part of the BY clause",
+            "1:32: can only use grouping function [bucket(a, 3)] as part of the BY clause",
             error("row a = 1 | stats sum(a) where bucket(a, 3) > -1 by bucket(a,2)")
         );
 
         assertEquals(
-            "1:40: can only use grouping function [bucket(salary, 10)] part of the BY clause",
+            "1:40: can only use grouping function [bucket(salary, 10)] as part of the BY clause",
             error("from test | stats max(languages) WHERE bucket(salary, 10) > 1 by emp_no")
         );
 
@@ -444,19 +444,19 @@ public class VerifierTests extends ESTestCase {
 
     public void testGroupingInsideAggsAsAgg() {
         assertEquals(
-            "1:18: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause",
+            "1:18: can only use grouping function [bucket(emp_no, 5.)] as part of the BY clause",
             error("from test| stats bucket(emp_no, 5.) by emp_no")
         );
         assertEquals(
-            "1:18: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause",
+            "1:18: can only use grouping function [bucket(emp_no, 5.)] as part of the BY clause",
             error("from test| stats bucket(emp_no, 5.)")
         );
         assertEquals(
-            "1:18: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause",
+            "1:18: can only use grouping function [bucket(emp_no, 5.)] as part of the BY clause",
             error("from test| stats bucket(emp_no, 5.) by bucket(emp_no, 6.)")
         );
         assertEquals(
-            "1:22: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause",
+            "1:22: can only use grouping function [bucket(emp_no, 5.)] as part of the BY clause",
             error("from test| stats 3 + bucket(emp_no, 5.) by bucket(emp_no, 6.)")
         );
     }
@@ -1846,7 +1846,7 @@ public class VerifierTests extends ESTestCase {
     }
 
     public void testCategorizeSingleGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
 
         query("from test | STATS COUNT(*) BY CATEGORIZE(first_name)");
         query("from test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)");
@@ -1875,7 +1875,7 @@ public class VerifierTests extends ESTestCase {
     }
 
     public void testCategorizeNestedGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
 
         query("from test | STATS COUNT(*) BY CATEGORIZE(LENGTH(first_name)::string)");
 
@@ -1890,27 +1890,33 @@ public class VerifierTests extends ESTestCase {
     }
 
     public void testCategorizeWithinAggregations() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
 
         query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)");
+        query("from test | STATS MV_COUNT(CATEGORIZE(first_name)), COUNT(*) BY cat = CATEGORIZE(first_name)");
+        query("from test | STATS MV_COUNT(CATEGORIZE(first_name)), COUNT(*) BY CATEGORIZE(first_name)");
 
         assertEquals(
-            "1:25: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] within the aggregations",
+            "1:25: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] within an aggregation",
             error("FROM test | STATS COUNT(CATEGORIZE(first_name)) BY CATEGORIZE(first_name)")
         );
-
         assertEquals(
-            "1:25: cannot reference CATEGORIZE grouping function [cat] within the aggregations",
+            "1:25: cannot reference CATEGORIZE grouping function [cat] within an aggregation",
             error("FROM test | STATS COUNT(cat) BY cat = CATEGORIZE(first_name)")
         );
         assertEquals(
-            "1:30: cannot reference CATEGORIZE grouping function [cat] within the aggregations",
+            "1:30: cannot reference CATEGORIZE grouping function [cat] within an aggregation",
             error("FROM test | STATS SUM(LENGTH(cat::keyword) + LENGTH(last_name)) BY cat = CATEGORIZE(first_name)")
         );
         assertEquals(
-            "1:25: cannot reference CATEGORIZE grouping function [`CATEGORIZE(first_name)`] within the aggregations",
+            "1:25: cannot reference CATEGORIZE grouping function [`CATEGORIZE(first_name)`] within an aggregation",
             error("FROM test | STATS COUNT(`CATEGORIZE(first_name)`) BY CATEGORIZE(first_name)")
         );
+
+        assertEquals(
+            "1:28: can only use grouping function [CATEGORIZE(last_name)] as part of the BY clause",
+            error("FROM test | STATS MV_COUNT(CATEGORIZE(last_name)) BY CATEGORIZE(first_name)")
+        );
     }
 
     public void testSortByAggregate() {

+ 2 - 2
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

@@ -1212,7 +1212,7 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
      *   \_EsRelation[test][_meta_field{f}#23, emp_no{f}#17, first_name{f}#18, ..]
      */
     public void testCombineProjectionWithCategorizeGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
 
         var plan = plan("""
             from test
@@ -3949,7 +3949,7 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
      *     \_EsRelation[test][_meta_field{f}#14, emp_no{f}#8, first_name{f}#9, ge..]
      */
     public void testNestedExpressionsInGroupsWithCategorize() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
 
         var plan = optimizedPlan("""
             from test