Browse Source

Disallow ES|QL CATEGORIZE in aggregation filters (#118319) (#118330)

Jan Kuipers 10 months ago
parent
commit
8f45b6bf4e

+ 17 - 3
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java

@@ -382,6 +382,18 @@ public class Verifier {
                     );
                 }
             })));
+        agg.aggregates().forEach(a -> a.forEachDown(FilteredExpression.class, fe -> fe.filter().forEachDown(Attribute.class, attribute -> {
+            var categorize = categorizeByAttribute.get(attribute);
+            if (categorize != null) {
+                failures.add(
+                    fail(
+                        attribute,
+                        "cannot reference CATEGORIZE grouping function [{}] within an aggregation filter",
+                        attribute.sourceText()
+                    )
+                );
+            }
+        })));
     }
 
     private static void checkRateAggregates(Expression expr, int nestedLevel, Set<Failure> failures) {
@@ -421,7 +433,8 @@ public class Verifier {
                 Expression filter = fe.filter();
                 failures.add(fail(filter, "WHERE clause allowed only for aggregate functions, none found in [{}]", fe.sourceText()));
             }
-            Expression f = fe.filter(); // check the filter has to be a boolean term, similar as checkFilterConditionType
+            Expression f = fe.filter();
+            // check the filter has to be a boolean term, similar as checkFilterConditionType
             if (f.dataType() != NULL && f.dataType() != BOOLEAN) {
                 failures.add(fail(f, "Condition expression needs to be boolean, found [{}]", f.dataType()));
             }
@@ -432,9 +445,10 @@ public class Verifier {
                         fail(af, "cannot use aggregate function [{}] in aggregate WHERE clause [{}]", af.sourceText(), fe.sourceText())
                     );
                 }
-                // check the bucketing function against the group
+                // check the grouping function against the group
                 else if (c instanceof GroupingFunction gf) {
-                    if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
+                    if (c instanceof Categorize
+                        || Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
                         failures.add(fail(gf, "can only use grouping function [{}] as part of the BY clause", gf.sourceText()));
                     }
                 }

+ 20 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

@@ -1968,6 +1968,26 @@ public class VerifierTests extends ESTestCase {
         );
     }
 
+    public void testCategorizeWithFilteredAggregations() {
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
+
+        query("FROM test | STATS COUNT(*) WHERE first_name == \"John\" BY CATEGORIZE(last_name)");
+        query("FROM test | STATS COUNT(*) WHERE last_name == \"Doe\" BY CATEGORIZE(last_name)");
+
+        assertEquals(
+            "1:34: can only use grouping function [CATEGORIZE(first_name)] as part of the BY clause",
+            error("FROM test | STATS COUNT(*) WHERE CATEGORIZE(first_name) == \"John\" BY CATEGORIZE(last_name)")
+        );
+        assertEquals(
+            "1:34: can only use grouping function [CATEGORIZE(last_name)] as part of the BY clause",
+            error("FROM test | STATS COUNT(*) WHERE CATEGORIZE(last_name) == \"Doe\" BY CATEGORIZE(last_name)")
+        );
+        assertEquals(
+            "1:34: cannot reference CATEGORIZE grouping function [category] within an aggregation filter",
+            error("FROM test | STATS COUNT(*) WHERE category == \"Doe\" BY category = CATEGORIZE(last_name)")
+        );
+    }
+
     public void testSortByAggregate() {
         assertEquals("1:18: Aggregate functions are not allowed in SORT [COUNT]", error("ROW a = 1 | SORT count(*)"));
         assertEquals("1:28: Aggregate functions are not allowed in SORT [COUNT]", error("ROW a = 1 | SORT to_string(count(*))"));