1
0
Эх сурвалжийг харах

[8.x] ESQL: Fail in AggregateFunction when `LogicPlan` is not an `Aggregate` (#125400)

Manual 8.x backport of https://github.com/elastic/elasticsearch/pull/124446, with support for Metrics aggs and removed Dedup support
Iván Cea Fontenla 7 сар өмнө
parent
commit
744459568b

+ 6 - 0
docs/changelog/124446.yaml

@@ -0,0 +1,6 @@
+pr: 124446
+summary: "ESQL: Fail in `AggregateFunction` when `LogicPlan` is not an `Aggregate`"
+area: ES|QL
+type: bug
+issues:
+  - 124311

+ 9 - 9
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java

@@ -19,8 +19,8 @@ import org.elasticsearch.xpack.esql.core.expression.function.Function;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
-import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
 
 import java.io.IOException;
 import java.util.List;
@@ -139,14 +139,14 @@ public abstract class AggregateFunction extends Function implements PostAnalysis
     @Override
     public BiConsumer<LogicalPlan, Failures> postAnalysisPlanVerification() {
         return (p, failures) -> {
-            if (p instanceof OrderBy order) {
-                order.order().forEach(o -> {
-                    o.forEachDown(Function.class, f -> {
-                        if (f instanceof AggregateFunction) {
-                            failures.add(fail(f, "Aggregate functions are not allowed in SORT [{}]", f.functionName()));
-                        }
-                    });
-                });
+            if ((p instanceof Aggregate) == false) {
+                p.expressions().forEach(x -> x.forEachDown(AggregateFunction.class, af -> {
+                    if (af instanceof Rate) {
+                        failures.add(fail(af, "aggregate function [{}] not allowed outside METRICS command", af.sourceText()));
+                    } else {
+                        failures.add(fail(af, "aggregate function [{}] not allowed outside STATS command", af.sourceText()));
+                    }
+                }));
             }
         };
     }

+ 0 - 10
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Eval.java

@@ -24,8 +24,6 @@ import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.type.DataType;
-import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
-import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 import org.elasticsearch.xpack.esql.plan.GeneratingPlan;
 
@@ -179,14 +177,6 @@ public class Eval extends UnaryPlan implements GeneratingPlan<Eval>, PostAnalysi
                     )
                 );
             }
-            // check no aggregate functions are used
-            field.forEachDown(AggregateFunction.class, af -> {
-                if (af instanceof Rate) {
-                    failures.add(fail(af, "aggregate function [{}] not allowed outside METRICS command", af.sourceText()));
-                } else {
-                    failures.add(fail(af, "aggregate function [{}] not allowed outside STATS command", af.sourceText()));
-                }
-            });
         });
     }
 }

+ 58 - 4
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

@@ -978,6 +978,17 @@ public class VerifierTests extends ESTestCase {
             line 1:23: Unknown column [avg]""", error("from test | stats c = avg by missing + 1, not_found"));
     }
 
+    public void testMultipleAggsOutsideStats() {
+        assertEquals(
+            """
+                1:71: aggregate function [avg(salary)] not allowed outside STATS command
+                line 1:96: aggregate function [median(emp_no)] not allowed outside STATS command
+                line 1:22: aggregate function [sum(salary)] not allowed outside STATS command
+                line 1:39: aggregate function [avg(languages)] not allowed outside STATS command""",
+            error("from test | eval s = sum(salary), l = avg(languages) | where salary > avg(salary) and emp_no > median(emp_no)")
+        );
+    }
+
     public void testSpatialSort() {
         String prefix = "ROW wkt = [\"POINT(42.9711 -14.7553)\", \"POINT(75.8093 22.7277)\"] | MV_EXPAND wkt ";
         assertEquals("1:130: cannot sort on geo_point", error(prefix + "| EVAL shape = TO_GEOPOINT(wkt) | limit 5 | sort shape"));
@@ -2024,10 +2035,53 @@ public class VerifierTests extends ESTestCase {
     }
 
     public void testSortByAggregate() {
-        assertEquals("1:18: Aggregate functions are not allowed in SORT [COUNT]", error("ROW a = 1 | SORT count(*)"));
-        assertEquals("1:28: Aggregate functions are not allowed in SORT [COUNT]", error("ROW a = 1 | SORT to_string(count(*))"));
-        assertEquals("1:22: Aggregate functions are not allowed in SORT [MAX]", error("ROW a = 1 | SORT 1 + max(a)"));
-        assertEquals("1:18: Aggregate functions are not allowed in SORT [COUNT]", error("FROM test | SORT count(*)"));
+        assertEquals("1:18: aggregate function [count(*)] not allowed outside STATS command", error("ROW a = 1 | SORT count(*)"));
+        assertEquals(
+            "1:28: aggregate function [count(*)] not allowed outside STATS command",
+            error("ROW a = 1 | SORT to_string(count(*))")
+        );
+        assertEquals("1:22: aggregate function [max(a)] not allowed outside STATS command", error("ROW a = 1 | SORT 1 + max(a)"));
+        assertEquals("1:18: aggregate function [count(*)] not allowed outside STATS command", error("FROM test | SORT count(*)"));
+    }
+
+    public void testFilterByAggregate() {
+        assertEquals("1:19: aggregate function [count(*)] not allowed outside STATS command", error("ROW a = 1 | WHERE count(*) > 0"));
+        assertEquals(
+            "1:29: aggregate function [count(*)] not allowed outside STATS command",
+            error("ROW a = 1 | WHERE to_string(count(*)) IS NOT NULL")
+        );
+        assertEquals("1:23: aggregate function [max(a)] not allowed outside STATS command", error("ROW a = 1 | WHERE 1 + max(a) > 0"));
+        assertEquals(
+            "1:24: aggregate function [min(languages)] not allowed outside STATS command",
+            error("FROM employees | WHERE min(languages) > 2")
+        );
+    }
+
+    public void testDissectByAggregate() {
+        assertEquals(
+            "1:21: aggregate function [min(first_name)] not allowed outside STATS command",
+            error("from test | dissect min(first_name) \"%{foo}\"")
+        );
+        assertEquals(
+            "1:21: aggregate function [avg(salary)] not allowed outside STATS command",
+            error("from test | dissect avg(salary) \"%{foo}\"")
+        );
+    }
+
+    public void testGrokByAggregate() {
+        assertEquals(
+            "1:18: aggregate function [max(last_name)] not allowed outside STATS command",
+            error("from test | grok max(last_name) \"%{WORD:foo}\"")
+        );
+        assertEquals(
+            "1:18: aggregate function [sum(salary)] not allowed outside STATS command",
+            error("from test | grok sum(salary) \"%{WORD:foo}\"")
+        );
+    }
+
+    public void testAggregateInRow() {
+        assertEquals("1:13: aggregate function [count(*)] not allowed outside STATS command", error("ROW a = 1 + count(*)"));
+        assertEquals("1:9: aggregate function [avg(2)] not allowed outside STATS command", error("ROW a = avg(2)"));
     }
 
     public void testLookupJoinDataTypeMismatch() {