Browse Source

QL: Further simplify tree traversals (#67211)

Follow-up to #67116, applying typed traversals in more places.

Relates #67116
Costin Leau 4 years ago
parent
commit
fc57255971

+ 1 - 6
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Analyzer.java

@@ -633,12 +633,7 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
                     final Map<Attribute, Expression> collectRefs = new LinkedHashMap<>();
 
                     // collect aliases
-                    child.forEachUp(p -> p.forEachExpressionUp(e -> {
-                        if (e instanceof Alias) {
-                            Alias a = (Alias) e;
-                            collectRefs.put(a.toAttribute(), a.child());
-                        }
-                    }));
+                    child.forEachUp(p -> p.forEachExpressionUp(Alias.class, a -> collectRefs.put(a.toAttribute(), a.child())));
 
                     referencesStream = referencesStream.filter(r -> {
                         for (Attribute attr : child.outputSet()) {

+ 12 - 23
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java

@@ -46,6 +46,7 @@ import org.elasticsearch.xpack.sql.expression.function.Score;
 import org.elasticsearch.xpack.sql.expression.function.aggregate.Kurtosis;
 import org.elasticsearch.xpack.sql.expression.function.aggregate.Max;
 import org.elasticsearch.xpack.sql.expression.function.aggregate.Min;
+import org.elasticsearch.xpack.sql.expression.function.aggregate.NumericAggregate;
 import org.elasticsearch.xpack.sql.expression.function.aggregate.Skewness;
 import org.elasticsearch.xpack.sql.expression.function.aggregate.TopHits;
 import org.elasticsearch.xpack.sql.expression.function.scalar.Cast;
@@ -187,12 +188,7 @@ public final class Verifier {
             // collect Attribute sources
             // only Aliases are interesting since these are the only ones that hide expressions
             // FieldAttribute for example are self replicating.
-            plan.forEachExpressionUp(e -> {
-                if (e instanceof Alias) {
-                    Alias a = (Alias) e;
-                    collectRefs.put(a.toAttribute(), a.child());
-                }
-            });
+            plan.forEachExpressionUp(Alias.class, a -> collectRefs.put(a.toAttribute(), a.child()));
 
             AttributeMap<Expression> attributeRefs = collectRefs.build();
 
@@ -692,12 +688,11 @@ public final class Verifier {
 
     private static void checkForScoreInsideFunctions(LogicalPlan p, Set<Failure> localFailures) {
         // Make sure that SCORE is only used in "top level" functions
-        p.forEachExpression(e ->
-            e.forEachUp(Function.class, (Function f) ->
-                f.arguments().stream()
-                    .filter(exp -> exp.anyMatch(Score.class::isInstance))
-                    .forEach(exp -> localFailures.add(fail(exp, "[SCORE()] cannot be an argument to a function")))
-            ));
+        p.forEachExpression(Function.class, f ->
+            f.arguments().stream()
+                .filter(exp -> exp.anyMatch(Score.class::isInstance))
+                .forEach(exp -> localFailures.add(fail(exp, "[SCORE()] cannot be an argument to a function")))
+        );
     }
 
     private static void checkNestedUsedInGroupByOrHavingOrWhereOrOrderBy(LogicalPlan p, Set<Failure> localFailures,
@@ -858,20 +853,15 @@ public final class Verifier {
     private static void checkMatrixStats(LogicalPlan p, Set<Failure> localFailures) {
         // MatrixStats aggregate functions cannot operates on scalars
         // https://github.com/elastic/elasticsearch/issues/55344
-        p.forEachExpression(e -> e.forEachUp(Kurtosis.class, (Kurtosis s) -> {
-            if (s.field() instanceof Function) {
-                localFailures.add(fail(s.field(), "[{}()] cannot be used on top of operators or scalars", s.functionName()));
-            }
-        }));
-        p.forEachExpression(e -> e.forEachUp(Skewness.class, (Skewness s) -> {
-            if (s.field() instanceof Function) {
+        p.forEachExpressionUp(NumericAggregate.class, s -> {
+            if ((s instanceof Kurtosis || s instanceof Skewness) && s.field() instanceof Function) {
                 localFailures.add(fail(s.field(), "[{}()] cannot be used on top of operators or scalars", s.functionName()));
             }
-        }));
+        });
     }
 
     private static void checkCastOnInexact(LogicalPlan p, Set<Failure> localFailures) {
-        p.forEachDown(Filter.class, f -> f.forEachExpressionUp(e -> e.forEachUp(Cast.class, (Cast c) -> {
+        p.forEachDown(Filter.class, f -> f.forEachExpressionUp(Cast.class, c -> {
             if (c.field() instanceof FieldAttribute) {
                 EsField.Exact exactInfo = ((FieldAttribute) c.field()).getExactInfo();
                 if (exactInfo.hasExact() == false
@@ -880,8 +870,7 @@ public final class Verifier {
                         "[{}] of data type [{}] cannot be used for [{}()] inside the WHERE clause",
                         c.field().sourceText(), c.field().dataType().typeName(), c.functionName()));
                 }
-
             }
-        })));
+        }));
     }
 }

+ 21 - 44
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java

@@ -217,12 +217,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
             final Map<Attribute, Expression> collectRefs = new LinkedHashMap<>();
 
             // collect aliases
-            plan.forEachExpressionUp(e -> {
-                if (e instanceof Alias) {
-                    Alias a = (Alias) e;
-                    collectRefs.put(a.toAttribute(), a.child());
-                }
-            });
+            plan.forEachExpressionUp(Alias.class, a -> collectRefs.put(a.toAttribute(), a.child()));
 
             plan = plan.transformUp(p -> {
                 // non attribute defining plans get their references removed
@@ -315,12 +310,9 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
                 // collect Attribute sources
                 // only Aliases are interesting since these are the only ones that hide expressions
                 // FieldAttribute for example are self replicating.
-                project.forEachUp(p -> p.forEachExpressionUp(e -> {
-                    if (e instanceof Alias) {
-                        Alias a = (Alias) e;
-                        if (a.child() instanceof Function) {
-                            collectRefs.put(a.toAttribute(), (Function) a.child());
-                        }
+                project.forEachUp(p -> p.forEachExpressionUp(Alias.class, a -> {
+                    if (a.child() instanceof Function) {
+                        collectRefs.put(a.toAttribute(), (Function) a.child());
                     }
                 }));
 
@@ -923,10 +915,8 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
             // 1. first check whether there are at least 2 aggs for the same fields so that there can be a promotion
             final Map<Expression, Match> potentialPromotions = new LinkedHashMap<>();
 
-            p.forEachExpressionUp(e -> {
-                if (Stats.isTypeCompatible(e)) {
-                    AggregateFunction f = (AggregateFunction) e;
-
+            p.forEachExpressionUp(AggregateFunction.class, f -> {
+                if (Stats.isTypeCompatible(f)) {
                     Expression argument = f.field();
                     Match match = potentialPromotions.get(argument);
 
@@ -971,13 +961,11 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
         public LogicalPlan apply(LogicalPlan plan) {
             final Map<Expression, Stats> statsPerField = new LinkedHashMap<>();
 
-            plan.forEachExpressionUp(e -> {
-                if (e instanceof Sum) {
-                    statsPerField.computeIfAbsent(((Sum) e).field(), field -> {
-                        Source source = new Source(field.sourceLocation(), "STATS(" + field.sourceText() + ")");
-                        return new Stats(source, field);
-                    });
-                }
+            plan.forEachExpressionUp(Sum.class, s -> {
+                statsPerField.computeIfAbsent(s.field(), field -> {
+                    Source source = new Source(field.sourceLocation(), "STATS(" + field.sourceText() + ")");
+                    return new Stats(source, field);
+                });
             });
 
             if (statsPerField.isEmpty() == false) {
@@ -995,13 +983,10 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
             final Map<Expression, ExtendedStats> seen = new LinkedHashMap<>();
 
             // count the extended stats
-            p.forEachExpressionUp(e -> {
-                if (e instanceof InnerAggregate) {
-                    InnerAggregate ia = (InnerAggregate) e;
-                    if (ia.outer() instanceof ExtendedStats) {
-                        ExtendedStats extStats = (ExtendedStats) ia.outer();
-                        seen.putIfAbsent(extStats.field(), extStats);
-                    }
+            p.forEachExpressionUp(InnerAggregate.class, ia -> {
+                if (ia.outer() instanceof ExtendedStats) {
+                    ExtendedStats extStats = (ExtendedStats) ia.outer();
+                    seen.putIfAbsent(extStats.field(), extStats);
                 }
             });
 
@@ -1043,13 +1028,9 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
         public LogicalPlan apply(LogicalPlan p) {
             Map<PercentileKey, Set<Expression>> percentsPerAggKey = new LinkedHashMap<>();
 
-            p.forEachExpressionUp(e -> {
-                if (e instanceof Percentile) {
-                    Percentile per = (Percentile) e;
-                    percentsPerAggKey.computeIfAbsent(new PercentileKey(per), v -> new LinkedHashSet<>())
-                        .add(per.percent());
-                }
-            });
+            p.forEachExpressionUp(Percentile.class, per ->
+                percentsPerAggKey.computeIfAbsent(new PercentileKey(per), v -> new LinkedHashSet<>()).add(per.percent())
+            );
 
             // create a Percentile agg for each agg key
             Map<PercentileKey, Percentiles> percentilesPerAggKey = new LinkedHashMap<>();
@@ -1072,13 +1053,9 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
         public LogicalPlan apply(LogicalPlan p) {
             final Map<PercentileKey, Set<Expression>> valuesPerAggKey = new LinkedHashMap<>();
 
-            p.forEachExpressionUp(e -> {
-                if (e instanceof PercentileRank) {
-                    PercentileRank per = (PercentileRank) e;
-                    valuesPerAggKey.computeIfAbsent(new PercentileKey(per), v -> new LinkedHashSet<>())
-                        .add(per.value());
-                }
-            });
+            p.forEachExpressionUp(PercentileRank.class, per ->
+                valuesPerAggKey.computeIfAbsent(new PercentileKey(per), v -> new LinkedHashSet<>()).add(per.value())
+            );
 
             // create a PercentileRank agg for each agg key
             Map<PercentileKey, PercentileRanks> ranksPerAggKey = new LinkedHashMap<>();