瀏覽代碼

SQL: Prevent grouping over grouping functions (#38649)

Improve verifier to disallow grouping over grouping functions (e.g.
HISTOGRAM over HISTOGRAM).

Close #38308
Costin Leau 6 年之前
父節點
當前提交
4e9b1cfd4d

+ 19 - 3
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java

@@ -593,20 +593,36 @@ public final class Verifier {
         // check if the query has a grouping function (Histogram) but no GROUP BY
         // check if the query has a grouping function (Histogram) but no GROUP BY
         if (p instanceof Project) {
         if (p instanceof Project) {
             Project proj = (Project) p;
             Project proj = (Project) p;
-            proj.projections().forEach(e -> e.forEachDown(f -> 
+            proj.projections().forEach(e -> e.forEachDown(f ->
                 localFailures.add(fail(f, "[{}] needs to be part of the grouping", Expressions.name(f))), GroupingFunction.class));
                 localFailures.add(fail(f, "[{}] needs to be part of the grouping", Expressions.name(f))), GroupingFunction.class));
         } else if (p instanceof Aggregate) {
         } else if (p instanceof Aggregate) {
-            // if it does have a GROUP BY, check if the groupings contain the grouping functions (Histograms) 
+            // if it does have a GROUP BY, check if the groupings contain the grouping functions (Histograms)
             Aggregate a = (Aggregate) p;
             Aggregate a = (Aggregate) p;
             a.aggregates().forEach(agg -> agg.forEachDown(e -> {
             a.aggregates().forEach(agg -> agg.forEachDown(e -> {
-                if (a.groupings().size() == 0 
+                if (a.groupings().size() == 0
                         || Expressions.anyMatch(a.groupings(), g -> g instanceof Function && e.functionEquals((Function) g)) == false) {
                         || Expressions.anyMatch(a.groupings(), g -> g instanceof Function && e.functionEquals((Function) g)) == false) {
                     localFailures.add(fail(e, "[{}] needs to be part of the grouping", Expressions.name(e)));
                     localFailures.add(fail(e, "[{}] needs to be part of the grouping", Expressions.name(e)));
                 }
                 }
+                else {
+                    checkGroupingFunctionTarget(e, localFailures);
+                }
+            }, GroupingFunction.class));
+
+            a.groupings().forEach(g -> g.forEachDown(e -> {
+                checkGroupingFunctionTarget(e, localFailures);
             }, GroupingFunction.class));
             }, GroupingFunction.class));
         }
         }
     }
     }
 
 
+    private static void checkGroupingFunctionTarget(GroupingFunction f, Set<Failure> localFailures) {
+        f.field().forEachDown(e -> {
+            if (e instanceof GroupingFunction) {
+                localFailures.add(fail(f.field(), "Cannot embed grouping functions within each other, found [{}] in [{}]",
+                        Expressions.name(f.field()), Expressions.name(f)));
+            }
+        });
+    }
+
     private static void checkFilterOnAggs(LogicalPlan p, Set<Failure> localFailures) {
     private static void checkFilterOnAggs(LogicalPlan p, Set<Failure> localFailures) {
         if (p instanceof Filter) {
         if (p instanceof Filter) {
             Filter filter = (Filter) p;
             Filter filter = (Filter) p;

+ 0 - 7
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expression.java

@@ -14,9 +14,6 @@ import org.elasticsearch.xpack.sql.type.DataType;
 import org.elasticsearch.xpack.sql.util.StringUtils;
 import org.elasticsearch.xpack.sql.util.StringUtils;
 
 
 import java.util.List;
 import java.util.List;
-import java.util.Locale;
-
-import static java.lang.String.format;
 
 
 /**
 /**
  * In a SQL statement, an Expression is whatever a user specifies inside an
  * In a SQL statement, an Expression is whatever a user specifies inside an
@@ -39,10 +36,6 @@ public abstract class Expression extends Node<Expression> implements Resolvable
             this(true, message);
             this(true, message);
         }
         }
 
 
-        TypeResolution(String message, Object... args) {
-            this(true, format(Locale.ROOT, message, args));
-        }
-
         private TypeResolution(boolean unresolved, String message) {
         private TypeResolution(boolean unresolved, String message) {
             this.failed = unresolved;
             this.failed = unresolved;
             this.message = message;
             this.message = message;

+ 2 - 2
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expressions.java

@@ -18,9 +18,9 @@ import java.util.Locale;
 import java.util.StringJoiner;
 import java.util.StringJoiner;
 import java.util.function.Predicate;
 import java.util.function.Predicate;
 
 
-import static java.lang.String.format;
 import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.emptyMap;
+import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
 import static org.elasticsearch.xpack.sql.type.DataType.BOOLEAN;
 import static org.elasticsearch.xpack.sql.type.DataType.BOOLEAN;
 
 
 public final class Expressions {
 public final class Expressions {
@@ -186,7 +186,7 @@ public final class Expressions {
                                             String... acceptedTypes) {
                                             String... acceptedTypes) {
         return predicate.test(e.dataType()) || DataTypes.isNull(e.dataType())?
         return predicate.test(e.dataType()) || DataTypes.isNull(e.dataType())?
             TypeResolution.TYPE_RESOLVED :
             TypeResolution.TYPE_RESOLVED :
-            new TypeResolution(format(Locale.ROOT, "[%s]%s argument must be [%s], found value [%s] type [%s]",
+                new TypeResolution(format(null, "[{}]{} argument must be [{}], found value [{}] type [{}]",
                 operationName,
                 operationName,
                 paramOrd == null || paramOrd == ParamOrdinal.DEFAULT ? "" : " " + paramOrd.name().toLowerCase(Locale.ROOT),
                 paramOrd == null || paramOrd == ParamOrdinal.DEFAULT ? "" : " " + paramOrd.name().toLowerCase(Locale.ROOT),
                 acceptedTypesForErrorMsg(acceptedTypes),
                 acceptedTypesForErrorMsg(acceptedTypes),

+ 13 - 3
x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java

@@ -566,10 +566,20 @@ public class VerifierErrorMessagesTests extends ESTestCase {
     }
     }
 
 
     public void testAggsInHistogram() {
     public void testAggsInHistogram() {
-        assertEquals("1:47: Cannot use an aggregate [MAX] for grouping",
-                error("SELECT MAX(date) FROM test GROUP BY HISTOGRAM(MAX(int), 1)"));
+        assertEquals("1:37: Cannot use an aggregate [MAX] for grouping",
+                error("SELECT MAX(date) FROM test GROUP BY MAX(int)"));
     }
     }
-    
+
+    public void testGroupingsInHistogram() {
+        assertEquals(
+                "1:47: Cannot embed grouping functions within each other, found [HISTOGRAM(int, 1)] in [HISTOGRAM(HISTOGRAM(int, 1), 1)]",
+                error("SELECT MAX(date) FROM test GROUP BY HISTOGRAM(HISTOGRAM(int, 1), 1)"));
+    }
+
+    public void testCastInHistogram() {
+        accept("SELECT MAX(date) FROM test GROUP BY HISTOGRAM(CAST(int AS LONG), 1)");
+    }
+
     public void testHistogramNotInGrouping() {
     public void testHistogramNotInGrouping() {
         assertEquals("1:8: [HISTOGRAM(date, INTERVAL 1 MONTH)] needs to be part of the grouping",
         assertEquals("1:8: [HISTOGRAM(date, INTERVAL 1 MONTH)] needs to be part of the grouping",
                 error("SELECT HISTOGRAM(date, INTERVAL 1 MONTH) AS h FROM test"));
                 error("SELECT HISTOGRAM(date, INTERVAL 1 MONTH) AS h FROM test"));