瀏覽代碼

SQL: Fix SUM(all zeroes) to return 0 instead of NULL (#65796)

Previously the SUM(all zeroes) was `NULL`, but after this change the SUM
SQL function call is automatically upgraded into a `stats` aggregation
instead of a `sum` aggregation. The `stats` aggregation only results in
`NULL` if there were no rows or no values (all nulls) to aggregate, which
is the expected behaviour across different SQL implementations.

This is a workaround for the issue #45251. Once the results of the `sum`
aggregation can differentiate between `SUM(all nulls)` and
`SUM(all zeroes)`, the optimizer rule introduced in this commit needs to be
removed.
Andras Palinkas 4 年之前
父節點
當前提交
b74792a8f2

+ 417 - 0
x-pack/plugin/sql/qa/server/src/main/resources/agg.csv-spec

@@ -1325,3 +1325,420 @@ F              |1964-10-18T00:00:00.000Z|1952-04-19T00:00:00.000Z
 M              |1965-01-03T00:00:00.000Z|1952-02-27T00:00:00.000Z
 ;
 
+
+//
+// Aggregations on NULLs and Zeros
+//
+
+allZerosWithFirst
+schema::FIRST_AllZeros:i
+SELECT FIRST(bytes_in) as "FIRST_AllZeros" FROM logs WHERE bytes_in = 0;
+
+FIRST_AllZeros 
+---------------
+0              
+;
+
+
+allNullsWithFirst
+schema::FIRST_AllNulls:i
+SELECT FIRST(bytes_out) as "FIRST_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+FIRST_AllNulls 
+---------------
+null           
+;
+
+
+allZerosWithLast
+schema::LAST_AllZeros:i
+SELECT LAST(bytes_in) as "LAST_AllZeros" FROM logs WHERE bytes_in = 0;
+
+ LAST_AllZeros 
+---------------
+0              
+;
+
+
+allNullsWithLast
+schema::LAST_AllNulls:i
+SELECT LAST(bytes_out) as "LAST_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+ LAST_AllNulls 
+---------------
+null           
+;
+
+
+allZerosWithCount
+schema::COUNT_AllZeros:l
+SELECT COUNT(bytes_in) as "COUNT_AllZeros" FROM logs WHERE bytes_in = 0;
+
+COUNT_AllZeros 
+---------------
+2              
+;
+
+
+allNullsWithCount
+schema::COUNT_AllNulls:l
+SELECT COUNT(bytes_out) as "COUNT_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+COUNT_AllNulls 
+---------------
+0              
+;
+
+
+
+allZerosWithAvg
+schema::AVG_AllZeros:d
+SELECT AVG(bytes_in) as "AVG_AllZeros" FROM logs WHERE bytes_in = 0;
+
+ AVG_AllZeros  
+---------------
+0.0            
+;
+
+
+allNullsWithAvg
+schema::AVG_AllNulls:d
+SELECT AVG(bytes_out) as "AVG_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+ AVG_AllNulls  
+---------------
+null           
+;
+
+
+allZerosWithMin
+schema::MIN_AllZeros:i
+SELECT MIN(bytes_in) as "MIN_AllZeros" FROM logs WHERE bytes_in = 0;
+
+ MIN_AllZeros  
+---------------
+0              
+;
+
+
+allNullsWithMin
+schema::MIN_AllNulls:i
+SELECT MIN(bytes_out) as "MIN_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+ MIN_AllNulls  
+---------------
+null           
+;
+
+
+allZerosWithMax
+schema::MAX_AllZeros:i
+SELECT MAX(bytes_in) as "MAX_AllZeros" FROM logs WHERE bytes_in = 0;
+
+ MAX_AllZeros  
+---------------
+0              
+;
+
+
+allNullsWithMax
+schema::MAX_AllNulls:i
+SELECT MAX(bytes_out) as "MAX_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+ MAX_AllNulls  
+---------------
+null           
+;
+
+
+allZerosWithSum
+schema::SUM_AllZeros:i
+SELECT SUM(bytes_in) as "SUM_AllZeros" FROM logs WHERE bytes_in = 0;
+
+ SUM_AllZeros  
+---------------
+0              
+;
+
+
+allNullsWithSum
+schema::SUM_AllNulls:i
+SELECT SUM(bytes_out) as "SUM_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+ SUM_AllNulls  
+---------------
+null           
+;
+
+
+allZerosWithPercentile
+schema::PERCENTILE_AllZeros:d
+SELECT PERCENTILE(bytes_in, 0) as "PERCENTILE_AllZeros" FROM logs WHERE bytes_in = 0;
+
+PERCENTILE_AllZeros
+-------------------
+0.0                
+;
+
+
+allNullsWithPercentile
+schema::PERCENTILE_AllNulls:d
+SELECT PERCENTILE(bytes_out, 0) as "PERCENTILE_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+PERCENTILE_AllNulls
+-------------------
+null               
+;
+
+
+allZerosWithPercentileRank
+schema::PERCENTILE_RANK_AllZeros:d
+SELECT PERCENTILE_RANK(bytes_in, 0) as "PERCENTILE_RANK_AllZeros" FROM logs WHERE bytes_in = 0;
+
+PERCENTILE_RANK_AllZeros
+------------------------
+100.0                   
+;
+
+
+allNullsWithPercentileRank
+schema::PERCENTILE_RANK_AllNulls:d
+SELECT PERCENTILE_RANK(bytes_out, 0) as "PERCENTILE_RANK_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+PERCENTILE_RANK_AllNulls
+------------------------
+null                    
+;
+
+
+allZerosWithSumOfSquares
+schema::SUM_OF_SQUARES_AllZeros:d
+SELECT SUM_OF_SQUARES(bytes_in) as "SUM_OF_SQUARES_AllZeros" FROM logs WHERE bytes_in = 0;
+
+SUM_OF_SQUARES_AllZeros
+-----------------------
+0.0                    
+;
+
+
+allNullsWithSumOfSquares
+schema::SUM_OF_SQUARES_AllNulls:d
+SELECT SUM_OF_SQUARES(bytes_out) as "SUM_OF_SQUARES_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+SUM_OF_SQUARES_AllNulls
+-----------------------
+null                   
+;
+
+
+allZerosWithStddevPop
+schema::STDDEV_POP_AllZeros:d
+SELECT STDDEV_POP(bytes_in) as "STDDEV_POP_AllZeros" FROM logs WHERE bytes_in = 0;
+
+STDDEV_POP_AllZeros
+-------------------
+0.0                
+;
+
+
+allNullsWithStddevPop
+schema::STDDEV_POP_AllNulls:d
+SELECT STDDEV_POP(bytes_out) as "STDDEV_POP_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+STDDEV_POP_AllNulls
+-------------------
+null               
+;
+
+
+allZerosWithStddevSamp
+schema::STDDEV_SAMP_AllZeros:d
+SELECT STDDEV_SAMP(bytes_in) as "STDDEV_SAMP_AllZeros" FROM logs WHERE bytes_in = 0;
+
+STDDEV_SAMP_AllZeros
+--------------------
+0.0                 
+;
+
+
+allNullsWithStddevSamp
+schema::STDDEV_SAMP_AllNulls:d
+SELECT STDDEV_SAMP(bytes_out) as "STDDEV_SAMP_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+STDDEV_SAMP_AllNulls
+--------------------
+null                
+;
+
+
+allZerosWithVarSamp
+schema::VAR_SAMP_AllZeros:d
+SELECT VAR_SAMP(bytes_in) as "VAR_SAMP_AllZeros" FROM logs WHERE bytes_in = 0;
+
+VAR_SAMP_AllZeros
+-----------------
+0.0              
+;
+
+
+allNullsWithVarSamp
+schema::VAR_SAMP_AllNulls:d
+SELECT VAR_SAMP(bytes_out) as "VAR_SAMP_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+VAR_SAMP_AllNulls
+-----------------
+null             
+;
+
+
+allZerosWithVarPop
+schema::VAR_POP_AllZeros:d
+SELECT VAR_POP(bytes_in) as "VAR_POP_AllZeros" FROM logs WHERE bytes_in = 0;
+
+VAR_POP_AllZeros
+----------------
+0.0             
+;
+
+
+allNullsWithVarPop
+schema::VAR_POP_AllNulls:d
+SELECT VAR_POP(bytes_out) as "VAR_POP_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+VAR_POP_AllNulls
+----------------
+null            
+;
+
+
+allZerosWithSkewness
+schema::SKEWNESS_AllZeros:d
+SELECT SKEWNESS(bytes_in) as "SKEWNESS_AllZeros" FROM logs WHERE bytes_in = 0;
+
+SKEWNESS_AllZeros
+-----------------
+NaN              
+;
+
+
+allNullsWithSkewness
+schema::SKEWNESS_AllNulls:d
+SELECT SKEWNESS(bytes_out) as "SKEWNESS_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+SKEWNESS_AllNulls
+-----------------
+null             
+;
+
+
+allZerosWithMad
+schema::MAD_AllZeros:d
+SELECT MAD(bytes_in) as "MAD_AllZeros" FROM logs WHERE bytes_in = 0;
+
+ MAD_AllZeros  
+---------------
+0.0            
+;
+
+
+allNullsWithMad
+schema::MAD_AllNulls:d
+SELECT MAD(bytes_out) as "MAD_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+ MAD_AllNulls  
+---------------
+NaN            
+;
+
+
+allZerosWithKurtosis
+schema::KURTOSIS_AllZeros:d
+SELECT KURTOSIS(bytes_in) as "KURTOSIS_AllZeros" FROM logs WHERE bytes_in = 0;
+
+KURTOSIS_AllZeros
+-----------------
+NaN              
+;
+
+
+allNullsWithKurtosis
+schema::KURTOSIS_AllNulls:d
+SELECT KURTOSIS(bytes_out) as "KURTOSIS_AllNulls" FROM logs WHERE bytes_out IS NULL;
+
+KURTOSIS_AllNulls
+-----------------
+null             
+;
+
+nullsAndZerosCombined
+schema::COUNT(*):l|COUNT_AllZeros:l|COUNT_AllNulls:l|FIRST_AllZeros:i|FIRST_AllNulls:i|SUM_AllZeros:i|SUM_AllNulls:i
+SELECT
+    COUNT(*), 
+    COUNT(bytes_in) AS "COUNT_AllZeros", 
+    COUNT(bytes_out) AS "COUNT_AllNulls", 
+    FIRST(bytes_in) AS "FIRST_AllZeros", 
+    FIRST(bytes_out) AS "FIRST_AllNulls", 
+    SUM(bytes_in) AS "SUM_AllZeros", 
+    SUM(bytes_out) AS "SUM_AllNulls"
+FROM logs
+WHERE bytes_in = 0 AND bytes_out IS NULL;
+
+   COUNT(*)    |COUNT_AllZeros |COUNT_AllNulls  |FIRST_AllZeros |FIRST_AllNulls | SUM_AllZeros  | SUM_AllNulls  
+---------------+---------------+----------------+---------------+---------------+---------------+---------------
+1              |1              |0               |0              |null           |0              |null           
+;
+
+
+groupedByNullsAndZeros
+schema::bytes_in:i|COUNT(*):l|SUM(bytes_in):i|MIN(bytes_in):i|MAX(bytes_in):i|AVG(bytes_in):d
+SELECT
+    bytes_in, 
+    COUNT(*), 
+    SUM(bytes_in), 
+    MIN(bytes_in), 
+    MAX(bytes_in), 
+    AVG(bytes_in)
+FROM logs
+WHERE NVL(bytes_in, 0) = 0
+GROUP BY bytes_in
+ORDER BY bytes_in DESC NULLS LAST;
+
+   bytes_in    |   COUNT(*)    | SUM(bytes_in) | MIN(bytes_in) | MAX(bytes_in) | AVG(bytes_in) 
+---------------+---------------+---------------+---------------+---------------+---------------
+0              |2              |0              |0              |0              |0.0            
+null           |1              |null           |null           |null           |null           
+;
+
+groupedByMultipleSumsWithNullsAndZeros
+schema::SUM(bytes_in):i|SUM(bytes_out):i|client_ip:s|c:l
+SELECT
+  SUM(bytes_in),
+  SUM(bytes_out),
+  client_ip,
+  COUNT(*) AS c
+FROM logs
+WHERE client_ip = '10.0.0.0/16' AND NVL(bytes_out, 0) = 0
+GROUP BY client_ip
+ORDER BY c DESC, SUM(bytes_in) ASC NULLS FIRST;
+
+ SUM(bytes_in) |SUM(bytes_out) |   client_ip   |       c       
+---------------+---------------+---------------+---------------
+232            |null           |10.0.1.199     |10             
+124            |null           |10.0.1.166     |7              
+336            |null           |10.0.1.122     |7              
+8              |null           |10.0.1.205     |2              
+16             |null           |10.0.1.201     |2              
+16             |null           |10.0.1.203     |2              
+28             |null           |10.0.1.207     |2              
+40             |null           |10.0.1.222     |2              
+56             |null           |10.0.0.130     |2              
+null           |null           |10.0.2.129     |1              
+8              |null           |10.0.1.202     |1              
+8              |null           |10.0.1.206     |1              
+8              |null           |10.0.1.208     |1              
+16             |null           |10.0.1.13      |1              
+28             |null           |10.0.0.107     |1              
+30             |null           |10.0.0.147     |1              
+32             |null           |10.0.1.177     |1              
+48             |null           |10.0.0.109     |1              
+;

+ 1 - 0
x-pack/plugin/sql/qa/server/src/main/resources/logs.csv

@@ -99,3 +99,4 @@ id,@timestamp,bytes_in,bytes_out,client_ip,client_port,dest_ip,status
 98,2017-11-10T21:12:24Z,74,90,10.0.0.134,57203,172.20.10.1,OK
 99,2017-11-10T21:17:37Z,39,512,10.0.0.128,29333,,OK
 100,2017-11-10T03:21:36Z,64,183,10.0.0.129,4541,172.16.1.1,OK
+101,2017-11-10T23:22:36Z,,,10.0.2.129,4541,172.20.11.1,OK

+ 12 - 1
x-pack/plugin/sql/qa/server/src/main/resources/pivot.csv-spec

@@ -197,10 +197,21 @@ null                 |10043          |Yishay         |M              |1990-10-20
 null                 |10044          |Mingsen        |F              |1994-05-21 00:00:00.0|Casley         |39728          |null
 1952-04-19 00:00:00.0|10009          |Sumant         |F              |1985-02-18 00:00:00.0|Peac           |66174          |null
 1953-01-07 00:00:00.0|10067          |Claudi         |M              |1987-03-04 00:00:00.0|Stavenow       |null           |52044
-
 // end::sumWithoutSubquery
 ;
 
+sumWithZeros
+SELECT *
+FROM (SELECT client_ip, status, bytes_in FROM logs WHERE NVL(bytes_in, 0) = 0)
+PIVOT (SUM(bytes_in) FOR status IN ('OK','Error'));
+
+   client_ip   |     'OK'      |    'Error'    
+---------------+---------------+---------------
+10.0.1.199     |0              |null           
+10.0.1.205     |0              |null           
+10.0.2.129     |null           |null           
+;
+
 sumWithInnerAggregateSumOfSquares
 schema::birth_date:ts|emp_no:i|first_name:s|gender:s|hire_date:ts|last_name:s|1:d|2:d
 SELECT * FROM test_emp PIVOT (SUM_OF_SQUARES(salary) FOR languages IN (1, 2)) LIMIT 5;

+ 34 - 0
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java

@@ -164,6 +164,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
                 new ReplaceAggsWithMatrixStats(),
                 new ReplaceAggsWithExtendedStats(),
                 new ReplaceAggsWithStats(),
+                new ReplaceSumWithStats(),
                 new PromoteStatsToExtendedStats(),
                 new ReplaceAggsWithPercentiles(),
                 new ReplaceAggsWithPercentileRanks()
@@ -983,6 +984,39 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
         }
     }
 
+    // Workaround for https://github.com/elastic/elasticsearch/issues/45251 (SUM(all zeroes) returned NULL): builds one Stats per summed
+    // field and replaces every Sum in the plan with an InnerAggregate over that Stats. Remove this rule as soon as the root cause is
+    // fixed and the sum aggregation results can differentiate between SUM(all zeroes) and SUM(all nulls).
+    // NOTE: this rule should always be applied AFTER the ReplaceAggsWithStats rule
+    static class ReplaceSumWithStats extends OptimizerBasicRule {
+        
+        @Override 
+        public LogicalPlan apply(LogicalPlan plan) {
+            final Map<Expression, Stats> statsPerField = new LinkedHashMap<>();
+            
+            plan.forEachExpressionsUp(e -> {
+                if (e instanceof Sum) {
+                    statsPerField.computeIfAbsent(((Sum) e).field(), field -> {
+                        Source source = new Source(field.sourceLocation(), "STATS(" + field.sourceText() + ")");
+                        return new Stats(source, field);
+                    });
+                }
+            });
+            
+            if (statsPerField.isEmpty() == false) {
+                plan = plan.transformExpressionsUp(e -> {
+                    if (e instanceof Sum) {
+                        Sum sum = (Sum) e;
+                        return new InnerAggregate(sum, statsPerField.get(sum.field()));
+                    }
+                    return e;
+                });
+            }
+            
+            return plan;
+        }
+    }
+
     static class PromoteStatsToExtendedStats extends OptimizerBasicRule {
 
         @Override

+ 67 - 28
x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java

@@ -119,7 +119,6 @@ import org.elasticsearch.xpack.sql.plan.logical.command.ShowTables;
 import org.elasticsearch.xpack.sql.session.EmptyExecutable;
 
 import java.lang.reflect.Constructor;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
@@ -219,7 +218,7 @@ public class OptimizerTests extends ESTestCase {
         // WHERE a < 10
         LogicalPlan p = new Filter(EMPTY, FROM(), lessThanOf(a, L(10)));
         // SELECT
-        p = new Project(EMPTY, p, Arrays.asList(a, b));
+        p = new Project(EMPTY, p, asList(a, b));
         // ORDER BY
         p = new OrderBy(EMPTY, p, singletonList(new Order(EMPTY, b, OrderDirection.ASC, null)));
 
@@ -269,14 +268,14 @@ public class OptimizerTests extends ESTestCase {
 
     public void testConstantFoldingIn() {
         In in = new In(EMPTY, ONE,
-            Arrays.asList(ONE, TWO, ONE, THREE, new Sub(EMPTY, THREE, ONE), ONE, FOUR, new Abs(EMPTY, new Sub(EMPTY, TWO, FIVE))));
+            asList(ONE, TWO, ONE, THREE, new Sub(EMPTY, THREE, ONE), ONE, FOUR, new Abs(EMPTY, new Sub(EMPTY, TWO, FIVE))));
         Literal result= (Literal) new ConstantFolding().rule(in);
         assertEquals(true, result.value());
     }
 
     public void testConstantFoldingIn_LeftValueNotFoldable() {
         In in = new In(EMPTY, getFieldAttribute(),
-                Arrays.asList(ONE, TWO, ONE, THREE, new Sub(EMPTY, THREE, ONE), ONE, FOUR, new Abs(EMPTY, new Sub(EMPTY, TWO, FIVE))));
+                asList(ONE, TWO, ONE, THREE, new Sub(EMPTY, THREE, ONE), ONE, FOUR, new Abs(EMPTY, new Sub(EMPTY, TWO, FIVE))));
         Alias as = new Alias(in.source(), in.sourceText(), in);
         Project p = new Project(EMPTY, FROM(), Collections.singletonList(as));
         p = (Project) new ConstantFolding().apply(p);
@@ -287,13 +286,13 @@ public class OptimizerTests extends ESTestCase {
     }
 
     public void testConstantFoldingIn_RightValueIsNull() {
-        In in = new In(EMPTY, getFieldAttribute(), Arrays.asList(NULL, NULL));
+        In in = new In(EMPTY, getFieldAttribute(), asList(NULL, NULL));
         Literal result= (Literal) new ConstantFolding().rule(in);
         assertNull(result.value());
     }
 
     public void testConstantFoldingIn_LeftValueIsNull() {
-        In in = new In(EMPTY, NULL, Arrays.asList(ONE, TWO, THREE));
+        In in = new In(EMPTY, NULL, asList(ONE, TWO, THREE));
         Literal result= (Literal) new ConstantFolding().rule(in);
         assertNull(result.value());
     }
@@ -426,9 +425,9 @@ public class OptimizerTests extends ESTestCase {
         Class<ArbitraryConditionalFunction> clazz =
             (Class<ArbitraryConditionalFunction>) randomFrom(Coalesce.class, Greatest.class, Least.class);
         Constructor<ArbitraryConditionalFunction> ctor = clazz.getConstructor(Source.class, List.class);
-        ArbitraryConditionalFunction conditionalFunction = ctor.newInstance(EMPTY, Arrays.asList(NULL, ONE, TWO));
+        ArbitraryConditionalFunction conditionalFunction = ctor.newInstance(EMPTY, asList(NULL, ONE, TWO));
         assertEquals(conditionalFunction, rule.rule(conditionalFunction));
-        conditionalFunction = ctor.newInstance(EMPTY, Arrays.asList(NULL, NULL, NULL));
+        conditionalFunction = ctor.newInstance(EMPTY, asList(NULL, NULL, NULL));
         assertEquals(conditionalFunction, rule.rule(conditionalFunction));
     }
 
@@ -461,7 +460,7 @@ public class OptimizerTests extends ESTestCase {
     public void testSimplifyCoalesceFirstLiteral() {
         Expression e = new SimplifyConditional()
                 .rule(new Coalesce(EMPTY,
-                        Arrays.asList(NULL, TRUE, FALSE, new Abs(EMPTY, getFieldAttribute()))));
+                        asList(NULL, TRUE, FALSE, new Abs(EMPTY, getFieldAttribute()))));
         assertEquals(Coalesce.class, e.getClass());
         assertEquals(1, e.children().size());
         assertEquals(TRUE, e.children().get(0));
@@ -585,7 +584,7 @@ public class OptimizerTests extends ESTestCase {
         // ELSE 'default'
         // END
 
-        Case c = new Case(EMPTY, Arrays.asList(
+        Case c = new Case(EMPTY, asList(
                 new IfConditional(EMPTY, equalsOf(getFieldAttribute(), ONE), literal("foo1")),
                 new IfConditional(EMPTY, equalsOf(ONE, TWO), literal("bar1")),
                 new IfConditional(EMPTY, equalsOf(TWO, ONE), literal("bar2")),
@@ -611,7 +610,7 @@ public class OptimizerTests extends ESTestCase {
         //
         // 'foo2'
 
-        Case c = new Case(EMPTY, Arrays.asList(
+        Case c = new Case(EMPTY, asList(
                 new IfConditional(EMPTY, equalsOf(ONE, TWO), literal("foo1")),
                 new IfConditional(EMPTY, equalsOf(ONE, ONE), literal("foo2")), literal("default")));
         assertFalse(c.foldable());
@@ -636,7 +635,7 @@ public class OptimizerTests extends ESTestCase {
         //
         // myField (non-foldable)
 
-        Case c = new Case(EMPTY, Arrays.asList(
+        Case c = new Case(EMPTY, asList(
                 new IfConditional(EMPTY, equalsOf(ONE, TWO), literal("foo1")),
                 getFieldAttribute("myField")));
         assertFalse(c.foldable());
@@ -794,8 +793,8 @@ public class OptimizerTests extends ESTestCase {
         Min min2 =  new Min(EMPTY, getFieldAttribute());
 
         OrderBy plan = new OrderBy(EMPTY, new Aggregate(EMPTY, FROM(), emptyList(),
-                Arrays.asList(a("min1", min1), a("min2", min2))),
-            Arrays.asList(
+                asList(a("min1", min1), a("min2", min2))),
+            asList(
                 new Order(EMPTY, min1, OrderDirection.ASC, Order.NullsPosition.LAST),
                 new Order(EMPTY, min2, OrderDirection.ASC, Order.NullsPosition.LAST)));
         LogicalPlan result = new ReplaceMinMaxWithTopHits().apply(plan);
@@ -819,8 +818,8 @@ public class OptimizerTests extends ESTestCase {
         Max max1 = new Max(EMPTY, new FieldAttribute(EMPTY, "str", new EsField("str", KEYWORD, emptyMap(), true)));
         Max max2 =  new Max(EMPTY, getFieldAttribute());
 
-        OrderBy plan = new OrderBy(EMPTY, new Aggregate(EMPTY, FROM(), emptyList(), Arrays.asList(a("max1", max1), a("max2", max2))),
-            Arrays.asList(
+        OrderBy plan = new OrderBy(EMPTY, new Aggregate(EMPTY, FROM(), emptyList(), asList(a("max1", max1), a("max2", max2))),
+            asList(
                 new Order(EMPTY, max1, OrderDirection.ASC, Order.NullsPosition.LAST),
                 new Order(EMPTY, max2, OrderDirection.ASC, Order.NullsPosition.LAST)));
         LogicalPlan result = new ReplaceMinMaxWithTopHits().apply(plan);
@@ -849,8 +848,8 @@ public class OptimizerTests extends ESTestCase {
         Order secondOrderBy = new Order(EMPTY, secondField, OrderDirection.ASC, Order.NullsPosition.LAST);
         
         OrderBy orderByPlan = new OrderBy(EMPTY,
-                new Aggregate(EMPTY, FROM(), Arrays.asList(secondField, firstField), Arrays.asList(secondAlias, firstAlias)),
-                Arrays.asList(firstOrderBy, secondOrderBy));
+                new Aggregate(EMPTY, FROM(), asList(secondField, firstField), asList(secondAlias, firstAlias)),
+                asList(firstOrderBy, secondOrderBy));
         LogicalPlan result = new SortAggregateOnOrderBy().apply(orderByPlan);
         
         assertTrue(result instanceof OrderBy);
@@ -881,8 +880,8 @@ public class OptimizerTests extends ESTestCase {
         Order secondOrderBy = new Order(EMPTY, secondAlias, OrderDirection.ASC, Order.NullsPosition.LAST);
         
         OrderBy orderByPlan = new OrderBy(EMPTY,
-                new Aggregate(EMPTY, FROM(), Arrays.asList(secondAlias, firstAlias), Arrays.asList(secondAlias, firstAlias)),
-                Arrays.asList(firstOrderBy, secondOrderBy));
+                new Aggregate(EMPTY, FROM(), asList(secondAlias, firstAlias), asList(secondAlias, firstAlias)),
+                asList(firstOrderBy, secondOrderBy));
         LogicalPlan result = new SortAggregateOnOrderBy().apply(orderByPlan);
         
         assertTrue(result instanceof OrderBy);
@@ -906,8 +905,8 @@ public class OptimizerTests extends ESTestCase {
     public void testPivotRewrite() {
         FieldAttribute column = getFieldAttribute("pivot");
         FieldAttribute number = getFieldAttribute("number");
-        List<NamedExpression> values = Arrays.asList(new Alias(EMPTY, "ONE", L(1)), new Alias(EMPTY, "TWO", L(2)));
-        List<NamedExpression> aggs = Arrays.asList(new Alias(EMPTY, "AVG", new Avg(EMPTY, number)));
+        List<NamedExpression> values = asList(new Alias(EMPTY, "ONE", L(1)), new Alias(EMPTY, "TWO", L(2)));
+        List<NamedExpression> aggs = asList(new Alias(EMPTY, "AVG", new Avg(EMPTY, number)));
         Pivot pivot = new Pivot(EMPTY, new EsRelation(EMPTY, new EsIndex("table", emptyMap()), false), column, values, aggs);
 
         LogicalPlan result = new RewritePivot().apply(pivot);
@@ -919,7 +918,7 @@ public class OptimizerTests extends ESTestCase {
         assertEquals(In.class, f.condition().getClass());
         In in = (In) f.condition();
         assertEquals(column, in.value());
-        assertEquals(Arrays.asList(L(1), L(2)), in.list());
+        assertEquals(asList(L(1), L(2)), in.list());
     }
 
     /**
@@ -933,7 +932,7 @@ public class OptimizerTests extends ESTestCase {
         FullTextPredicate matchPredicate = new MatchQueryPredicate(EMPTY, matchField, "A", StringUtils.EMPTY);
         FullTextPredicate multiMatchPredicate = new MultiMatchQueryPredicate(EMPTY, "match_field", "A", StringUtils.EMPTY);
         FullTextPredicate stringQueryPredicate = new StringQueryPredicate(EMPTY, "match_field:A", StringUtils.EMPTY);
-        List<FullTextPredicate> predicates = Arrays.asList(matchPredicate, multiMatchPredicate, stringQueryPredicate);
+        List<FullTextPredicate> predicates = asList(matchPredicate, multiMatchPredicate, stringQueryPredicate);
 
         FullTextPredicate left = randomFrom(predicates);
         FullTextPredicate right = randomFrom(predicates);
@@ -946,15 +945,15 @@ public class OptimizerTests extends ESTestCase {
         List<AggregateFunction> aggregates;
         boolean isSimpleStats = randomBoolean();
         if (isSimpleStats) {
-            aggregates = Arrays.asList(new Avg(EMPTY, aggField), new Sum(EMPTY, aggField), new Min(EMPTY, aggField),
+            aggregates = asList(new Avg(EMPTY, aggField), new Sum(EMPTY, aggField), new Min(EMPTY, aggField),
                     new Max(EMPTY, aggField));
         } else {
-            aggregates = Arrays.asList(new StddevPop(EMPTY, aggField), new SumOfSquares(EMPTY, aggField), new VarPop(EMPTY, aggField));
+            aggregates = asList(new StddevPop(EMPTY, aggField), new SumOfSquares(EMPTY, aggField), new VarPop(EMPTY, aggField));
         }
         AggregateFunction firstAggregate = randomFrom(aggregates);
         AggregateFunction secondAggregate = randomValueOtherThan(firstAggregate, () -> randomFrom(aggregates));
         Aggregate aggregatePlan = new Aggregate(EMPTY, filter, singletonList(matchField),
-                Arrays.asList(new Alias(EMPTY, "first", firstAggregate), new Alias(EMPTY, "second", secondAggregate)));
+                asList(new Alias(EMPTY, "first", firstAggregate), new Alias(EMPTY, "second", secondAggregate)));
         LogicalPlan result;
         if (isSimpleStats) {
             result = new ReplaceAggsWithStats().apply(aggregatePlan);
@@ -1001,7 +1000,7 @@ public class OptimizerTests extends ESTestCase {
         Alias aAlias = new Alias(EMPTY, "aAlias", a);
         Alias bAlias = new Alias(EMPTY, "bAlias", b);
         
-        Project p = new Project(EMPTY, FROM(), Arrays.asList(aAlias, bAlias));
+        Project p = new Project(EMPTY, FROM(), asList(aAlias, bAlias));
         Filter f = new Filter(EMPTY, p, new And(EMPTY, greaterThanOf(aAlias.toAttribute(), L(1)),
             greaterThanOf(bAlias.toAttribute(), L(2))));
         
@@ -1023,4 +1022,44 @@ public class OptimizerTests extends ESTestCase {
         gt = (GreaterThan) and.left();
         assertEquals(a, gt.left());
     }
+
+    //
+    // ReplaceSumWithStats rule: the optimized aggregate must expose the original Sum as the inner() of an InnerAggregate
+    //
+    public void testSumIsReplacedWithStats() {
+        FieldAttribute fa = getFieldAttribute();
+        Sum sum = new Sum(EMPTY, fa);
+        
+        Alias sumAlias = new Alias(EMPTY, "sum", sum);
+        
+        Aggregate aggregate = new Aggregate(EMPTY, FROM(), emptyList(), asList(sumAlias));
+        LogicalPlan optimizedPlan = new Optimizer().optimize(aggregate);
+        assertTrue(optimizedPlan instanceof Aggregate);
+        Aggregate p = (Aggregate) optimizedPlan; 
+        assertEquals(1, p.aggregates().size());
+        assertTrue(p.aggregates().get(0) instanceof Alias);
+        Alias alias = (Alias) p.aggregates().get(0);
+        assertTrue(alias.child() instanceof InnerAggregate);
+        assertEquals(sum, ((InnerAggregate) alias.child()).inner());
+    }
+
+    /**
+     * Once the root cause of https://github.com/elastic/elasticsearch/issues/45251 is fixed in the <code>sum</code> ES aggregation 
+     * (i.e. it can differentiate between <code>SUM(all zeroes)</code> and <code>SUM(all nulls)</code>), 
+     * remove {@link OptimizerTests#testSumIsReplacedWithStats()} and re-enable this test (the aliased Sum must stay unchanged).
+     */
+    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/45251")
+    public void testSumIsNotReplacedWithStats() {
+        FieldAttribute fa = getFieldAttribute();
+        Sum sum = new Sum(EMPTY, fa);
+
+        Alias sumAlias = new Alias(EMPTY, "sum", sum);
+
+        Aggregate aggregate = new Aggregate(EMPTY, FROM(), emptyList(), asList(sumAlias));
+        LogicalPlan optimizedPlan = new Optimizer().optimize(aggregate);
+        assertTrue(optimizedPlan instanceof Aggregate);
+        Aggregate p = (Aggregate) optimizedPlan;
+        assertEquals(1, p.aggregates().size());
+        assertEquals(sumAlias, p.aggregates().get(0));
+    }
 }

+ 19 - 3
x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java

@@ -84,7 +84,6 @@ import org.junit.BeforeClass;
 
 import java.time.ZoneId;
 import java.time.ZonedDateTime;
-import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 import java.util.Locale;
@@ -95,6 +94,7 @@ import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import java.util.stream.Stream;
 
+import static java.util.Arrays.asList;
 import static org.elasticsearch.xpack.ql.type.DataTypes.BOOLEAN;
 import static org.elasticsearch.xpack.ql.type.DataTypes.DATETIME;
 import static org.elasticsearch.xpack.ql.type.DataTypes.DOUBLE;
@@ -153,7 +153,7 @@ public class QueryTranslatorTests extends ESTestCase {
         }
 
         private LogicalPlan parameterizedSql(String sql, SqlTypedParamValue... params) {
-            return analyzer.analyze(parser.createStatement(sql, Arrays.asList(params), DateUtils.UTC), true);
+            return analyzer.analyze(parser.createStatement(sql, asList(params), DateUtils.UTC), true);
         }
     }
 
@@ -1048,7 +1048,7 @@ public class QueryTranslatorTests extends ESTestCase {
         assertFalse(bq.isAnd());
         assertTrue(bq.left() instanceof RangeQuery);
         assertTrue(bq.right() instanceof RangeQuery);
-        List<Tuple<String, RangeQuery>> tuples = Arrays.asList(new Tuple<>(dates[0], (RangeQuery)bq.left()),
+        List<Tuple<String, RangeQuery>> tuples = asList(new Tuple<>(dates[0], (RangeQuery)bq.left()),
             new Tuple<>(dates[1], (RangeQuery) bq.right()));
 
         for (Tuple<String, RangeQuery> t: tuples) {
@@ -2443,4 +2443,20 @@ public class QueryTranslatorTests extends ESTestCase {
         test.accept("PERCENTILE", p -> ((PercentilesAggregationBuilder)p).percentiles());
         test.accept("PERCENTILE_RANK", p -> ((PercentileRanksAggregationBuilder)p).values());
     }
+
+    // Tests the workaround for the SUM(all zeros) = NULL issue raised in https://github.com/elastic/elasticsearch/issues/45251: each
+    // query must translate to a stats aggregation on the summed field. Remove this test as soon as the root cause is fixed and the sum
+    // aggregation results can differentiate between SUM(all zeroes) and SUM(all nulls)
+    public void testReplaceSumWithStats() {
+        List<String> testCases = asList(
+            "SELECT keyword, SUM(int) FROM test GROUP BY keyword",
+            "SELECT SUM(int) FROM test",
+            "SELECT * FROM (SELECT some.string, keyword, int FROM test) PIVOT (SUM(int) FOR keyword IN ('a', 'b'))");
+        for (String testCase : testCases) {
+            PhysicalPlan physicalPlan = optimizeAndPlan(testCase);
+            assertEquals(EsQueryExec.class, physicalPlan.getClass());
+            EsQueryExec eqe = (EsQueryExec) physicalPlan;
+            assertThat(eqe.queryContainer().toString().replaceAll("\\s+", ""), containsString("{\"stats\":{\"field\":\"int\"}}"));
+        }
+    }
 }