Browse Source

ESQL: Sum, Min, Max and Avg of constants (#105454)

Allow expressions like
... | STATS sum([1, -9]), sum(null), min(21.0*3), avg([1,2,3])
by substituting sum(const) by mv_sum(const)*count(*) and min(const) by
mv_min(const) (and similarly for max and avg).
Alexander Spies 1 year ago
parent
commit
829ea4d34d

+ 5 - 0
docs/changelog/105454.yaml

@@ -0,0 +1,5 @@
+pr: 105454
+summary: "ESQL: Sum of constants"
+area: ES|QL
+type: enhancement
+issues: []

+ 96 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

@@ -1226,3 +1226,99 @@ FROM employees
 vals:l
 183
 ;
+
+sumOfConst#[skip:-8.13.99,reason:supported in 8.14]
+FROM employees
+| STATS s1 = sum(1), s2point1 = sum(2.1), s_mv = sum([-1, 0, 3]) * 3, s_null = sum(null), rows = count(*)
+;
+
+s1:l | s2point1:d | s_mv:l | s_null:d | rows:l
+100  | 210.0      | 600    | null     | 100
+;
+
+sumOfConstGrouped#[skip:-8.13.99,reason:supported in 8.14]
+FROM employees
+| STATS s2point1 = round(sum(2.1), 1), s_mv = sum([-1, 0, 3]), rows = count(*) by languages
+| SORT languages
+;
+
+s2point1:d | s_mv:l | rows:l | languages:i
+31.5       | 30     | 15     | 1
+39.9       | 38     | 19     | 2
+35.7       | 34     | 17     | 3
+37.8       | 36     | 18     | 4
+44.1       | 42     | 21     | 5
+21.0       | 20     | 10     | null
+;
+
+avgOfConst#[skip:-8.13.99,reason:supported in 8.14]
+FROM employees
+| STATS s1 = avg(1), s_mv = avg([-1, 0, 3]) * 3, s_null = avg(null)
+;
+
+s1:d | s_mv:d | s_null:d
+1.0  | 2.0    | null
+;
+
+avgOfConstGrouped#[skip:-8.13.99,reason:supported in 8.14]
+FROM employees
+| STATS s2point1 = avg(2.1), s_mv = avg([-1, 0, 3]) * 3 by languages
+| SORT languages
+;
+
+s2point1:d | s_mv:d | languages:i
+2.1        | 2.0    | 1
+2.1        | 2.0    | 2
+2.1        | 2.0    | 3
+2.1        | 2.0    | 4
+2.1        | 2.0    | 5
+2.1        | 2.0    | null
+;
+
+minOfConst#[skip:-8.13.99,reason:supported in 8.14]
+FROM employees
+| STATS s1 = min(1), s_mv = min([-1, 0, 3]), s_null = min(null)
+;
+
+s1:i | s_mv:i | s_null:null
+1    | -1     | null
+;
+
+minOfConstGrouped#[skip:-8.13.99,reason:supported in 8.14]
+FROM employees
+| STATS s2point1 = min(2.1), s_mv = min([-1, 0, 3]) by languages
+| SORT languages
+;
+
+s2point1:d | s_mv:i | languages:i
+2.1        | -1     | 1
+2.1        | -1     | 2
+2.1        | -1     | 3
+2.1        | -1     | 4
+2.1        | -1     | 5
+2.1        | -1     | null
+;
+
+maxOfConst#[skip:-8.13.99,reason:supported in 8.14]
+FROM employees
+| STATS s1 = max(1), s_mv = max([-1, 0, 3]), s_null = max(null)
+;
+
+s1:i | s_mv:i | s_null:null
+1    | 3      | null
+;
+
+maxOfConstGrouped#[skip:-8.13.99,reason:supported in 8.14]
+FROM employees
+| STATS s2point1 = max(2.1), s_mv = max([-1, 0, 3]) by languages
+| SORT languages
+;
+
+s2point1:d | s_mv:i | languages:i
+2.1        | 3      | 1
+2.1        | 3      | 2
+2.1        | 3      | 3
+2.1        | 3      | 4
+2.1        | 3      | 5
+2.1        | 3      | null
+;

+ 3 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/SurrogateExpression.java

@@ -15,5 +15,8 @@ import org.elasticsearch.xpack.ql.expression.Expression;
  */
 public interface SurrogateExpression {
 
+    /**
+     * Returns the expression to be replaced by or {@code null} if this cannot be replaced.
+     */
     Expression surrogate();
 }

+ 3 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate;
 import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
 import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
 import org.elasticsearch.xpack.esql.expression.function.Param;
+import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
 import org.elasticsearch.xpack.ql.expression.Expression;
 import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
@@ -60,6 +61,7 @@ public class Avg extends AggregateFunction implements SurrogateExpression {
     public Expression surrogate() {
         var s = source();
         var field = field();
-        return new Div(s, new Sum(s, field), new Count(s, field), dataType());
+
+        return field().foldable() ? new MvAvg(s, field) : new Div(s, new Sum(s, field), new Count(s, field), dataType());
     }
 }

+ 8 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java

@@ -11,8 +11,10 @@ import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.MaxDoubleAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.MaxIntAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier;
+import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
 import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
 import org.elasticsearch.xpack.esql.expression.function.Param;
+import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax;
 import org.elasticsearch.xpack.ql.expression.Expression;
 import org.elasticsearch.xpack.ql.tree.NodeInfo;
 import org.elasticsearch.xpack.ql.tree.Source;
@@ -20,7 +22,7 @@ import org.elasticsearch.xpack.ql.type.DataType;
 
 import java.util.List;
 
-public class Max extends NumericAggregate {
+public class Max extends NumericAggregate implements SurrogateExpression {
 
     @FunctionInfo(returnType = { "double", "integer", "long" }, description = "The maximum value of a numeric field.", isAggregation = true)
     public Max(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
@@ -61,4 +63,9 @@ public class Max extends NumericAggregate {
     protected AggregatorFunctionSupplier doubleSupplier(List<Integer> inputChannels) {
         return new MaxDoubleAggregatorFunctionSupplier(inputChannels);
     }
+
+    @Override
+    public Expression surrogate() {
+        return field().foldable() ? new MvMax(source(), field()) : null;
+    }
 }

+ 8 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java

@@ -11,8 +11,10 @@ import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.MinDoubleAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.MinIntAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.MinLongAggregatorFunctionSupplier;
+import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
 import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
 import org.elasticsearch.xpack.esql.expression.function.Param;
+import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin;
 import org.elasticsearch.xpack.ql.expression.Expression;
 import org.elasticsearch.xpack.ql.tree.NodeInfo;
 import org.elasticsearch.xpack.ql.tree.Source;
@@ -20,7 +22,7 @@ import org.elasticsearch.xpack.ql.type.DataType;
 
 import java.util.List;
 
-public class Min extends NumericAggregate {
+public class Min extends NumericAggregate implements SurrogateExpression {
 
     @FunctionInfo(returnType = { "double", "integer", "long" }, description = "The minimum value of a numeric field.", isAggregation = true)
     public Min(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
@@ -61,4 +63,9 @@ public class Min extends NumericAggregate {
     protected AggregatorFunctionSupplier doubleSupplier(List<Integer> inputChannels) {
         return new MinDoubleAggregatorFunctionSupplier(inputChannels);
     }
+
+    @Override
+    public Expression surrogate() {
+        return field().foldable() ? new MvMin(source(), field()) : null;
+    }
 }

+ 18 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java

@@ -10,12 +10,18 @@ import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.SumIntAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier;
+import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
 import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
 import org.elasticsearch.xpack.esql.expression.function.Param;
+import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum;
+import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
 import org.elasticsearch.xpack.ql.expression.Expression;
+import org.elasticsearch.xpack.ql.expression.Literal;
 import org.elasticsearch.xpack.ql.tree.NodeInfo;
 import org.elasticsearch.xpack.ql.tree.Source;
 import org.elasticsearch.xpack.ql.type.DataType;
+import org.elasticsearch.xpack.ql.type.DataTypes;
+import org.elasticsearch.xpack.ql.util.StringUtils;
 
 import java.util.List;
 
@@ -26,7 +32,7 @@ import static org.elasticsearch.xpack.ql.type.DataTypes.UNSIGNED_LONG;
 /**
  * Sum all values of a field in matching documents.
  */
-public class Sum extends NumericAggregate {
+public class Sum extends NumericAggregate implements SurrogateExpression {
 
     @FunctionInfo(returnType = "long", description = "The sum of a numeric field.", isAggregation = true)
     public Sum(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
@@ -63,4 +69,15 @@ public class Sum extends NumericAggregate {
     protected AggregatorFunctionSupplier doubleSupplier(List<Integer> inputChannels) {
         return new SumDoubleAggregatorFunctionSupplier(inputChannels);
     }
+
+    @Override
+    public Expression surrogate() {
+        var s = source();
+        var field = field();
+
+        // SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
+        return field.foldable()
+            ? new Mul(s, new MvSum(s, field), new Count(s, new Literal(s, StringUtils.WILDCARD, DataTypes.KEYWORD)))
+            : null;
+    }
 }

+ 41 - 21
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java

@@ -34,6 +34,7 @@ import org.elasticsearch.xpack.ql.expression.Alias;
 import org.elasticsearch.xpack.ql.expression.Attribute;
 import org.elasticsearch.xpack.ql.expression.AttributeMap;
 import org.elasticsearch.xpack.ql.expression.AttributeSet;
+import org.elasticsearch.xpack.ql.expression.EmptyAttribute;
 import org.elasticsearch.xpack.ql.expression.Expression;
 import org.elasticsearch.xpack.ql.expression.ExpressionSet;
 import org.elasticsearch.xpack.ql.expression.Expressions;
@@ -107,6 +108,23 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
         return rules();
     }
 
+    protected static Batch<LogicalPlan> substitutions() {
+        return new Batch<>(
+            "Substitutions",
+            Limiter.ONCE,
+            // first extract nested aggs top-level - this simplifies the rest of the rules
+            new ReplaceStatsAggExpressionWithEval(),
+            // second extract nested aggs inside of them
+            new ReplaceStatsNestedExpressionWithEval(),
+            // lastly replace surrogate functions
+            new SubstituteSurrogates(),
+            new ReplaceRegexMatch(),
+            new ReplaceAliasingEvalWithProject(),
+            new SkipQueryOnEmptyMappings()
+            // new NormalizeAggregate(), - waits on https://github.com/elastic/elasticsearch/issues/100634
+        );
+    }
+
     protected static Batch<LogicalPlan> operators() {
         return new Batch<>(
             "Operator Optimization",
@@ -150,26 +168,11 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
     }
 
     protected static List<Batch<LogicalPlan>> rules() {
-        var substitutions = new Batch<>(
-            "Substitutions",
-            Limiter.ONCE,
-            // first extract nested aggs top-level - this simplifies the rest of the rules
-            new ReplaceStatsAggExpressionWithEval(),
-            // second extract nested aggs inside of them
-            new ReplaceStatsNestedExpressionWithEval(),
-            // lastly replace surrogate functions
-            new SubstituteSurrogates(),
-            new ReplaceRegexMatch(),
-            new ReplaceAliasingEvalWithProject(),
-            new SkipQueryOnEmptyMappings()
-            // new NormalizeAggregate(), - waits on https://github.com/elastic/elasticsearch/issues/100634
-        );
-
         var skip = new Batch<>("Skip Compute", new SkipQueryOnLimitZero());
         var defaultTopN = new Batch<>("Add default TopN", new AddDefaultTopN());
         var label = new Batch<>("Set as Optimized", Limiter.ONCE, new SetAsOptimized());
 
-        return asList(substitutions, operators(), skip, cleanup(), defaultTopN, label);
+        return asList(substitutions(), operators(), skip, cleanup(), defaultTopN, label);
     }
 
     // TODO: currently this rule only works for aggregate functions (AVG)
@@ -191,8 +194,10 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
 
             // first pass to check existing aggregates (to avoid duplication and alias waste)
             for (NamedExpression agg : aggs) {
-                if (Alias.unwrap(agg) instanceof AggregateFunction af && af instanceof SurrogateExpression == false) {
-                    aggFuncToAttr.put(af, agg.toAttribute());
+                if (Alias.unwrap(agg) instanceof AggregateFunction af) {
+                    if ((af instanceof SurrogateExpression se && se.surrogate() != null) == false) {
+                        aggFuncToAttr.put(af, agg.toAttribute());
+                    }
                 }
             }
 
@@ -200,7 +205,7 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
             // 0. check list of surrogate expressions
             for (NamedExpression agg : aggs) {
                 Expression e = Alias.unwrap(agg);
-                if (e instanceof SurrogateExpression sf) {
+                if (e instanceof SurrogateExpression sf && sf.surrogate() != null) {
                     changed = true;
                     Expression s = sf.surrogate();
 
@@ -240,9 +245,22 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
             LogicalPlan plan = aggregate;
             if (changed) {
                 var source = aggregate.source();
-                plan = new Aggregate(aggregate.source(), aggregate.child(), aggregate.groupings(), newAggs);
+                if (newAggs.isEmpty() == false) {
+                    plan = new Aggregate(source, aggregate.child(), aggregate.groupings(), newAggs);
+                } else {
+                    // All aggs actually have been surrogates for (foldable) expressions, e.g.
+                    // \_Aggregate[[],[AVG([1, 2][INTEGER]) AS s]]
+                    // Replace by a local relation with one row, followed by an eval, e.g.
+                    // \_Eval[[MVAVG([1, 2][INTEGER]) AS s]]
+                    // \_LocalRelation[[{e}#21],[ConstantNullBlock[positions=1]]]
+                    plan = new LocalRelation(
+                        source,
+                        List.of(new EmptyAttribute(source)),
+                        LocalSupplier.of(new Block[] { BlockUtils.constantBlock(PlannerUtils.NON_BREAKING_BLOCK_FACTORY, null, 1) })
+                    );
+                }
                 // 5. force the initial projection in place
-                if (transientEval.size() > 0) {
+                if (transientEval.isEmpty() == false) {
                     plan = new Eval(source, plan, transientEval);
                     // project away transient fields and re-enforce the original order using references (not copies) to the original aggs
                     // this works since the replaced aliases have their nameId copied to avoid having to update all references (which has
@@ -500,6 +518,8 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
 
             plan = plan.transformUp(p -> {
                 // Apply the replacement inside Filter and Eval (which shouldn't make a difference)
+                // TODO: also allow aggregates once aggs on constants are supported.
+                // C.f. https://github.com/elastic/elasticsearch/issues/100634
                 if (p instanceof Filter || p instanceof Eval) {
                     p = p.transformExpressionsOnly(ReferenceAttribute.class, replaceReference);
                 }

+ 2 - 6
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java

@@ -11,11 +11,9 @@ import org.elasticsearch.compute.aggregation.IntermediateStateDesc;
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
-import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Max;
-import org.elasticsearch.xpack.esql.expression.function.aggregate.Median;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.NumericAggregate;
@@ -43,7 +41,6 @@ import java.lang.invoke.MethodType;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.function.Predicate;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -55,12 +52,11 @@ public class AggregateMapper {
     static final List<String> NUMERIC = List.of("Int", "Long", "Double");
     static final List<String> SPATIAL = List.of("GeoPoint", "CartesianPoint");
 
-    /** List of all ESQL agg functions. */
+    /** List of all mappable ESQL agg functions (excludes surrogates like AVG = SUM/COUNT). */
     static final List<? extends Class<? extends Function>> AGG_FUNCTIONS = List.of(
         Count.class,
         CountDistinct.class,
         Max.class,
-        Median.class,
         MedianAbsoluteDeviation.class,
         Min.class,
         Percentile.class,
@@ -79,7 +75,7 @@ public class AggregateMapper {
     private final HashMap<Expression, List<? extends NamedExpression>> cache = new HashMap<>();
 
     AggregateMapper() {
-        this(AGG_FUNCTIONS.stream().filter(Predicate.not(SurrogateExpression.class::isAssignableFrom)).toList());
+        this(AGG_FUNCTIONS);
     }
 
     AggregateMapper(List<? extends Class<? extends Function>> aggregateFunctionClasses) {

+ 191 - 1
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

@@ -112,9 +112,11 @@ import org.elasticsearch.xpack.ql.type.EsField;
 import org.junit.BeforeClass;
 
 import java.lang.reflect.Constructor;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Function;
 
 import static java.util.Arrays.asList;
 import static java.util.Collections.emptyList;
@@ -173,6 +175,19 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
     private static Analyzer analyzerAirports;
     private static EnrichResolution enrichResolution;
 
+    private static class SubstitutionOnlyOptimizer extends LogicalPlanOptimizer {
+        static SubstitutionOnlyOptimizer INSTANCE = new SubstitutionOnlyOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG));
+
+        SubstitutionOnlyOptimizer(LogicalOptimizerContext optimizerContext) {
+            super(optimizerContext);
+        }
+
+        @Override
+        protected List<Batch<LogicalPlan>> batches() {
+            return List.of(substitutions());
+        }
+    }
+
     @BeforeClass
     public static void init() {
         parser = new EsqlParser();
@@ -3272,6 +3287,177 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
         assertThat(Expressions.attribute(fields.get(1)), is(Expressions.attribute(sum_argument)));
     }
 
+    /**
+     * Expects after running the {@link LogicalPlanOptimizer#substitutions()}:
+     *
+     * Limit[1000[INTEGER]]
+     * \_EsqlProject[[s{r}#3, s_expr{r}#5, s_null{r}#7, w{r}#10]]
+     *   \_Project[[s{r}#3, s_expr{r}#5, s_null{r}#7, w{r}#10]]
+     *     \_Eval[[MVSUM([1, 2][INTEGER]) * $$COUNT$s$0{r}#25 AS s, MVSUM(314.0[DOUBLE] / 100[INTEGER]) * $$COUNT$s$0{r}#25 AS s
+     * _expr, MVSUM(null[NULL]) * $$COUNT$s$0{r}#25 AS s_null]]
+     *       \_Aggregate[[w{r}#10],[COUNT(*[KEYWORD]) AS $$COUNT$s$0, w{r}#10]]
+     *         \_Eval[[emp_no{f}#15 % 2[INTEGER] AS w]]
+     *           \_EsRelation[test][_meta_field{f}#21, emp_no{f}#15, first_name{f}#16, ..]
+     */
+    public void testSumOfLiteral() {
+        var plan = plan("""
+            from test
+            | stats s = sum([1,2]),
+                    s_expr = sum(314.0/100),
+                    s_null = sum(null)
+                    by w = emp_no % 2
+            | keep s, s_expr, s_null, w
+            """, SubstitutionOnlyOptimizer.INSTANCE);
+
+        var limit = as(plan, Limit.class);
+        var esqlProject = as(limit.child(), EsqlProject.class);
+        var project = as(esqlProject.child(), Project.class);
+        var eval = as(project.child(), Eval.class);
+        var agg = as(eval.child(), Aggregate.class);
+
+        var exprs = eval.fields();
+        // s = count(*) * 3
+        var s = as(exprs.get(0), Alias.class);
+        assertThat(s.name(), equalTo("s"));
+        var mul = as(s.child(), Mul.class);
+        var mvSum = as(mul.left(), MvSum.class);
+        assertThat(mvSum.fold(), equalTo(3));
+        var count = as(mul.right(), ReferenceAttribute.class);
+        assertThat(count.name(), equalTo("$$COUNT$s$0"));
+
+        // s_expr = count(*) * 3.14
+        var s_expr = as(exprs.get(1), Alias.class);
+        assertThat(s_expr.name(), equalTo("s_expr"));
+        var mul_expr = as(s_expr.child(), Mul.class);
+        var mvSum_expr = as(mul_expr.left(), MvSum.class);
+        assertThat(mvSum_expr.fold(), equalTo(3.14));
+        var count_expr = as(mul_expr.right(), ReferenceAttribute.class);
+        assertThat(count_expr.name(), equalTo("$$COUNT$s$0"));
+
+        // s_null = null
+        var s_null = as(exprs.get(2), Alias.class);
+        assertThat(s_null.name(), equalTo("s_null"));
+        var mul_null = as(s_null.child(), Mul.class);
+        var mvSum_null = as(mul_null.left(), MvSum.class);
+        assertThat(mvSum_null.field(), equalTo(NULL));
+        var count_null = as(mul_null.right(), ReferenceAttribute.class);
+        assertThat(count_null.name(), equalTo("$$COUNT$s$0"));
+
+        var count_agg = as(Alias.unwrap(agg.aggregates().get(0)), Count.class);
+        assertThat(count_agg.children().get(0), instanceOf(Literal.class));
+        var w = as(Alias.unwrap(agg.groupings().get(0)), ReferenceAttribute.class);
+        assertThat(w.name(), equalTo("w"));
+    }
+
+    private record AggOfLiteralTestCase(
+        String aggFunctionName,
+        Class<? extends Expression> substitution,
+        Function<int[], Object> aggMultiValue
+    ) {};
+
+    private static List<AggOfLiteralTestCase> AGG_OF_CONST_CASES = List.of(
+        new AggOfLiteralTestCase("avg", MvAvg.class, ints -> ((double) Arrays.stream(ints).sum()) / ints.length),
+        new AggOfLiteralTestCase("min", MvMin.class, ints -> Arrays.stream(ints).min().getAsInt()),
+        new AggOfLiteralTestCase("max", MvMax.class, ints -> Arrays.stream(ints).max().getAsInt())
+    );
+
+    /**
+     * Aggs of literals in case that the agg can be simply replaced by a corresponding mv-function;
+     * e.g. avg([1,2,3]) which is equivalent to mv_avg([1,2,3]).
+     *
+     * Expects after running the {@link LogicalPlanOptimizer#substitutions()}:
+     *
+     * Limit[1000[INTEGER]]
+     * \_EsqlProject[[s{r}#3, s_expr{r}#5, s_null{r}#7]]
+     *   \_Project[[s{r}#3, s_expr{r}#5, s_null{r}#7]]
+     *     \_Eval[[MVAVG([1, 2][INTEGER]) AS s, MVAVG(314.0[DOUBLE] / 100[INTEGER]) AS s_expr, MVAVG(null[NULL]) AS s_null]]
+     *       \_LocalRelation[[{e}#21],[ConstantNullBlock[positions=1]]]
+     */
+    public void testAggOfLiteral() {
+        for (AggOfLiteralTestCase testCase : AGG_OF_CONST_CASES) {
+            String query = LoggerMessageFormat.format(null, """
+                from test
+                | stats s = {}([1,2]),
+                        s_expr = {}(314.0/100),
+                        s_null = {}(null)
+                | keep s, s_expr, s_null
+                """, testCase.aggFunctionName, testCase.aggFunctionName, testCase.aggFunctionName);
+
+            var plan = plan(query, SubstitutionOnlyOptimizer.INSTANCE);
+
+            var limit = as(plan, Limit.class);
+            var esqlProject = as(limit.child(), EsqlProject.class);
+            var project = as(esqlProject.child(), Project.class);
+            var eval = as(project.child(), Eval.class);
+            var singleRowRelation = as(eval.child(), LocalRelation.class);
+            var singleRow = singleRowRelation.supplier().get();
+            assertThat(singleRow.length, equalTo(1));
+            assertThat(singleRow[0].getPositionCount(), equalTo(1));
+
+            var exprs = eval.fields();
+            var s = as(exprs.get(0), Alias.class);
+            assertThat(s.child(), instanceOf(testCase.substitution));
+            assertThat(s.child().fold(), equalTo(testCase.aggMultiValue.apply(new int[] { 1, 2 })));
+            var s_expr = as(exprs.get(1), Alias.class);
+            assertThat(s_expr.child(), instanceOf(testCase.substitution));
+            assertThat(s_expr.child().fold(), equalTo(3.14));
+            var s_null = as(exprs.get(2), Alias.class);
+            assertThat(s_null.child(), instanceOf(testCase.substitution));
+            assertThat(s_null.child().fold(), equalTo(null));
+        }
+    }
+
+    /**
+     * Like {@link LogicalPlanOptimizerTests#testAggOfLiteral()} but with a grouping key.
+     *
+     * Expects after running the {@link LogicalPlanOptimizer#substitutions()}:
+     *
+     * Limit[1000[INTEGER]]
+     * \_EsqlProject[[s{r}#3, s_expr{r}#5, s_null{r}#7, emp_no{f}#13]]
+     *   \_Project[[s{r}#3, s_expr{r}#5, s_null{r}#7, emp_no{f}#13]]
+     *     \_Eval[[MVAVG([1, 2][INTEGER]) AS s, MVAVG(314.0[DOUBLE] / 100[INTEGER]) AS s_expr, MVAVG(null[NULL]) AS s_null]]
+     *       \_Aggregate[[emp_no{f}#13],[emp_no{f}#13]]
+     *         \_EsRelation[test][_meta_field{f}#19, emp_no{f}#13, first_name{f}#14, ..]
+     */
+    public void testAggOfLiteralGrouped() {
+        for (AggOfLiteralTestCase testCase : AGG_OF_CONST_CASES) {
+            String query = LoggerMessageFormat.format(null, """
+                    from test
+                    | stats s = {}([1,2]),
+                            s_expr = {}(314.0/100),
+                            s_null = {}(null)
+                            by emp_no
+                    | keep s, s_expr, s_null, emp_no
+                """, testCase.aggFunctionName, testCase.aggFunctionName, testCase.aggFunctionName);
+
+            var plan = plan(query, SubstitutionOnlyOptimizer.INSTANCE);
+
+            var limit = as(plan, Limit.class);
+            var esqlProject = as(limit.child(), EsqlProject.class);
+            var project = as(esqlProject.child(), Project.class);
+            var eval = as(project.child(), Eval.class);
+            var agg = as(eval.child(), Aggregate.class);
+            assertThat(agg.child(), instanceOf(EsRelation.class));
+
+            // Assert exprs
+            var exprs = eval.fields();
+
+            var s = as(exprs.get(0), Alias.class);
+            assertThat(s.child(), instanceOf(testCase.substitution));
+            assertThat(s.child().fold(), equalTo(testCase.aggMultiValue.apply(new int[] { 1, 2 })));
+            var s_expr = as(exprs.get(1), Alias.class);
+            assertThat(s_expr.child(), instanceOf(testCase.substitution));
+            assertThat(s_expr.child().fold(), equalTo(3.14));
+            var s_null = as(exprs.get(2), Alias.class);
+            assertThat(s_null.child(), instanceOf(testCase.substitution));
+            assertThat(s_null.child().fold(), equalTo(null));
+
+            // Assert that the aggregate only does the grouping by emp_no
+            assertThat(Expressions.names(agg.groupings()), contains("emp_no"));
+            assertThat(agg.aggregates().size(), equalTo(1));
+        }
+    }
+
     public void testEmptyMappingIndex() {
         EsIndex empty = new EsIndex("empty_test", emptyMap(), emptySet());
         IndexResolution getIndexResultAirports = IndexResolution.valid(empty);
@@ -3455,9 +3641,13 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
     }
 
     private LogicalPlan plan(String query) {
+        return plan(query, logicalOptimizer);
+    }
+
+    private LogicalPlan plan(String query, LogicalPlanOptimizer optimizer) {
         var analyzed = analyzer.analyze(parser.createStatement(query));
         // System.out.println(analyzed);
-        var optimized = logicalOptimizer.optimize(analyzed);
+        var optimized = optimizer.optimize(analyzed);
         // System.out.println(optimized);
         return optimized;
     }