Browse Source

Change how grouping aggs track null (ESQL-1328)

This changes how grouping aggs track `null` to fix bugs in the old
tracking, such as those caused by `null`s arriving after non-null
values. Now there are two "modes" for the `X-ArrayState` classes:
1. A mode where we do not track null and rely on the `selected` list of
buckets produced by the `BlockHash`. This mode is appropriate right up
until you see the first `null` value.
2. A mode where we do track which values are null and can reply with
`null` even for values `BlockHash` selects if we've only ever added
`null` values to that slot. This is appropriate for all data, but less
efficient so we only transition to it when we receive our first `Block`
containing `null` values.

The transition is *interesting* because we need to know which values
aren't `null`. Luckily, `BlockHash` has that information. Usually it's
just "everything in this range" but for `boolean` values it isn't. And,
I expect, for ordinals it won't be either. So we just ask `BlockHash` at
the moment of the transition. That required changing the interface
around a little bit so we could ask *before* the group keys for the next
block were added.
Nik Everett 2 years ago
parent
commit
1b75dee430
49 changed files with 890 additions and 1059 deletions
  1. 30 47
      x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java
  2. 5 0
      x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java
  3. 35 52
      x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleArrayState.java
  4. 35 52
      x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntArrayState.java
  5. 39 61
      x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongArrayState.java
  6. 8 28
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java
  7. 8 30
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java
  8. 8 28
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java
  9. 8 28
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java
  10. 8 28
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java
  11. 14 38
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java
  12. 14 38
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java
  13. 14 38
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java
  14. 8 28
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java
  15. 8 28
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java
  16. 8 28
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java
  17. 14 38
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java
  18. 14 38
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java
  19. 14 38
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java
  20. 8 28
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java
  21. 8 28
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java
  22. 8 28
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java
  23. 8 28
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java
  24. 14 38
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java
  25. 14 38
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java
  26. 53 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractArrayState.java
  27. 5 15
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctBooleanAggregator.java
  28. 41 36
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java
  29. 2 2
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java
  30. 1 1
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java
  31. 4 10
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java
  32. 13 10
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/QuantileStates.java
  33. 38 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SeenGroupIds.java
  34. 15 52
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumDoubleAggregator.java
  35. 39 61
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ArrayState.java.st
  36. 7 1
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java
  37. 13 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BooleanBlockHash.java
  38. 7 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BytesRefBlockHash.java
  39. 7 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BytesRefLongBlockHash.java
  40. 7 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/DoubleBlockHash.java
  41. 7 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/IntBlockHash.java
  42. 7 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/LongBlockHash.java
  43. 7 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/LongLongBlockHash.java
  44. 7 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java
  45. 2 2
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java
  46. 30 8
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java
  47. 192 0
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/ArrayStateTests.java
  48. 34 5
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java
  49. 0 2
      x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java

+ 30 - 47
x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

@@ -23,7 +23,6 @@ import java.util.List;
 import java.util.Locale;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
-import java.util.stream.Stream;
 
 import javax.lang.model.element.ExecutableElement;
 import javax.lang.model.element.Modifier;
@@ -51,6 +50,7 @@ import static org.elasticsearch.compute.gen.Types.LIST_INTEGER;
 import static org.elasticsearch.compute.gen.Types.LONG_BLOCK;
 import static org.elasticsearch.compute.gen.Types.LONG_VECTOR;
 import static org.elasticsearch.compute.gen.Types.PAGE;
+import static org.elasticsearch.compute.gen.Types.SEEN_GROUP_IDS;
 import static org.elasticsearch.compute.gen.Types.blockType;
 import static org.elasticsearch.compute.gen.Types.vectorType;
 
@@ -93,10 +93,10 @@ public class GroupingAggregatorImplementer {
         this.combineIntermediate = findMethod(declarationType, "combineIntermediate");
         this.evaluateFinal = findMethod(declarationType, "evaluateFinal");
         this.valuesIsBytesRef = BYTES_REF.equals(TypeName.get(combine.getParameters().get(combine.getParameters().size() - 1).asType()));
-        List<Parameter> createParameters = init.getParameters().stream().map(Parameter::from).toList();
-        this.createParameters = createParameters.stream().anyMatch(p -> p.type().equals(BIG_ARRAYS))
-            ? createParameters
-            : Stream.concat(Stream.of(new Parameter(BIG_ARRAYS, "bigArrays")), createParameters.stream()).toList();
+        this.createParameters = init.getParameters().stream().map(Parameter::from).collect(Collectors.toList());
+        if (false == createParameters.stream().anyMatch(p -> p.type().equals(BIG_ARRAYS))) {
+            createParameters.add(0, new Parameter(BIG_ARRAYS, "bigArrays"));
+        }
 
         this.implementation = ClassName.get(
             elements.getPackageOf(declarationType).toString(),
@@ -161,10 +161,8 @@ public class GroupingAggregatorImplementer {
         builder.addMethod(prepareProcessPage());
         builder.addMethod(addRawInputLoop(LONG_VECTOR, valueBlockType(init, combine)));
         builder.addMethod(addRawInputLoop(LONG_VECTOR, valueVectorType(init, combine)));
-        builder.addMethod(addRawInputLoop(LONG_VECTOR, BLOCK));
         builder.addMethod(addRawInputLoop(LONG_BLOCK, valueBlockType(init, combine)));
         builder.addMethod(addRawInputLoop(LONG_BLOCK, valueVectorType(init, combine)));
-        builder.addMethod(addRawInputLoop(LONG_BLOCK, BLOCK));
         builder.addMethod(addIntermediateInput());
         builder.addMethod(addIntermediateRowInput());
         builder.addMethod(evaluateIntermediate());
@@ -250,21 +248,24 @@ public class GroupingAggregatorImplementer {
     private MethodSpec prepareProcessPage() {
         MethodSpec.Builder builder = MethodSpec.methodBuilder("prepareProcessPage");
         builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC).returns(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT);
-        builder.addParameter(PAGE, "page");
+        builder.addParameter(SEEN_GROUP_IDS, "seenGroupIds").addParameter(PAGE, "page");
 
         builder.addStatement("$T uncastValuesBlock = page.getBlock(channels.get(0))", BLOCK);
+
         builder.beginControlFlow("if (uncastValuesBlock.areAllValuesNull())");
         {
-            builder.addStatement(
-                "return $L",
-                addInput(b -> b.addStatement("addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock)"))
-            );
+            builder.addStatement("state.enableGroupIdTracking(seenGroupIds)");
+            builder.addStatement("return $L", addInput(b -> {}));
         }
         builder.endControlFlow();
+
         builder.addStatement("$T valuesBlock = ($T) uncastValuesBlock", valueBlockType(init, combine), valueBlockType(init, combine));
         builder.addStatement("$T valuesVector = valuesBlock.asVector()", valueVectorType(init, combine));
         builder.beginControlFlow("if (valuesVector == null)");
         {
+            builder.beginControlFlow("if (valuesBlock.mayHaveNulls())");
+            builder.addStatement("state.enableGroupIdTracking(seenGroupIds)");
+            builder.endControlFlow();
             builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock)")));
         }
         builder.endControlFlow();
@@ -299,18 +300,8 @@ public class GroupingAggregatorImplementer {
      */
     private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) {
         boolean groupsIsBlock = groupsType.toString().endsWith("Block");
-        enum ValueType {
-            VECTOR,
-            TYPED_BLOCK,
-            NULL_ONLY_BLOCK
-        }
-        ValueType valueType = valuesType.equals(BLOCK) ? ValueType.NULL_ONLY_BLOCK
-            : valuesType.toString().endsWith("Block") ? ValueType.TYPED_BLOCK
-            : ValueType.VECTOR;
+        boolean valuesIsBlock = valuesType.toString().endsWith("Block");
         String methodName = "addRawInput";
-        if (valueType == ValueType.NULL_ONLY_BLOCK) {
-            methodName += "AllNulls";
-        }
         MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName);
         builder.addModifiers(Modifier.PRIVATE);
         builder.addParameter(TypeName.INT, "positionOffset").addParameter(groupsType, "groups").addParameter(valuesType, "values");
@@ -333,23 +324,17 @@ public class GroupingAggregatorImplementer {
                 builder.addStatement("int groupId = Math.toIntExact(groups.getLong(groupPosition))");
             }
 
-            switch (valueType) {
-                case VECTOR -> combineRawInput(builder, "values", "groupPosition + positionOffset");
-                case TYPED_BLOCK -> {
-                    builder.beginControlFlow("if (values.isNull(groupPosition + positionOffset))");
-                    builder.addStatement("state.putNull(groupId)");
-                    builder.addStatement("continue");
-                    builder.endControlFlow();
-                    builder.addStatement("int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset)");
-                    builder.addStatement("int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset)");
-                    builder.beginControlFlow("for (int v = valuesStart; v < valuesEnd; v++)");
-                    combineRawInput(builder, "values", "v");
-                    builder.endControlFlow();
-                }
-                case NULL_ONLY_BLOCK -> {
-                    builder.addStatement("assert values.isNull(groupPosition + positionOffset)");
-                    builder.addStatement("state.putNull(groupPosition + positionOffset)");
-                }
+            if (valuesIsBlock) {
+                builder.beginControlFlow("if (values.isNull(groupPosition + positionOffset))");
+                builder.addStatement("continue");
+                builder.endControlFlow();
+                builder.addStatement("int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset)");
+                builder.addStatement("int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset)");
+                builder.beginControlFlow("for (int v = valuesStart; v < valuesEnd; v++)");
+                combineRawInput(builder, "values", "v");
+                builder.endControlFlow();
+            } else {
+                combineRawInput(builder, "values", "groupPosition + positionOffset");
             }
 
             if (groupsIsBlock) {
@@ -391,7 +376,7 @@ public class GroupingAggregatorImplementer {
         String offsetVariable
     ) {
         builder.addStatement(
-            "state.set($T.combine(state.getOrDefault(groupId), $L.$L($L)), groupId)",
+            "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L($L)))",
             declarationType,
             blockVariable,
             secondParameterGetter,
@@ -426,6 +411,7 @@ public class GroupingAggregatorImplementer {
         builder.addParameter(LONG_VECTOR, "groups");
         builder.addParameter(PAGE, "page");
 
+        builder.addStatement("state.enableGroupIdTracking(new $T.Empty())", SEEN_GROUP_IDS);
         builder.addStatement("assert channels.size() == intermediateBlockCount()");
         int count = 0;
         for (var interState : intermediateState) {
@@ -461,13 +447,11 @@ public class GroupingAggregatorImplementer {
                     var name = intermediateState.get(0).name();
                     var m = vectorAccessorName(intermediateState.get(0).elementType());
                     builder.addStatement(
-                        "state.set($T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)), groupId)",
+                        "state.set(groupId, $T.combine($L.$L(groupPosition + positionOffset), state.getOrDefault(groupId)))",
                         declarationType,
                         name,
                         m
                     );
-                    builder.nextControlFlow("else");
-                    builder.addStatement("state.putNull(groupId)");
                     builder.endControlFlow();
                 }
             } else {
@@ -493,9 +477,7 @@ public class GroupingAggregatorImplementer {
     private void combineStates(MethodSpec.Builder builder) {
         if (combineStates == null) {
             builder.beginControlFlow("if (inState.hasValue(position))");
-            builder.addStatement("state.set($T.combine(state.getOrDefault(groupId), inState.get(position)), groupId)", declarationType);
-            builder.nextControlFlow("else");
-            builder.addStatement("state.putNull(groupId)");
+            builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType);
             builder.endControlFlow();
             return;
         }
@@ -512,6 +494,7 @@ public class GroupingAggregatorImplementer {
         }
         builder.endControlFlow();
         builder.addStatement("$T inState = (($T) input).state", stateType, implementation);
+        builder.addStatement("state.enableGroupIdTracking(new $T.Empty())", SEEN_GROUP_IDS);
         combineStates(builder);
         return builder.build();
     }

+ 5 - 0
x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java

@@ -71,6 +71,10 @@ public class Types {
     static final ClassName LONG_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantLongVector");
     static final ClassName DOUBLE_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantDoubleVector");
 
+    static final ClassName INT_ARRAY_STATE = ClassName.get(AGGREGATION_PACKAGE, "IntArrayState");
+    static final ClassName LONG_ARRAY_STATE = ClassName.get(AGGREGATION_PACKAGE, "LongArrayState");
+    static final ClassName DOUBLE_ARRAY_STATE = ClassName.get(AGGREGATION_PACKAGE, "DoubleArrayState");
+
     static final ClassName AGGREGATOR_FUNCTION = ClassName.get(AGGREGATION_PACKAGE, "AggregatorFunction");
     static final ClassName AGGREGATOR_FUNCTION_SUPPLIER = ClassName.get(AGGREGATION_PACKAGE, "AggregatorFunctionSupplier");
     static final ClassName GROUPING_AGGREGATOR_FUNCTION = ClassName.get(AGGREGATION_PACKAGE, "GroupingAggregatorFunction");
@@ -79,6 +83,7 @@ public class Types {
         "GroupingAggregatorFunction",
         "AddInput"
     );
+    static final ClassName SEEN_GROUP_IDS = ClassName.get(AGGREGATION_PACKAGE, "SeenGroupIds");
 
     static final ClassName INTERMEDIATE_STATE_DESC = ClassName.get(AGGREGATION_PACKAGE, "IntermediateStateDesc");
     static final TypeName LIST_AGG_FUNC_DESC = ParameterizedTypeName.get(ClassName.get(List.class), INTERMEDIATE_STATE_DESC);

+ 35 - 52
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/DoubleArrayState.java

@@ -8,7 +8,6 @@
 package org.elasticsearch.compute.aggregation;
 
 import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.DoubleArray;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BooleanBlock;
@@ -18,68 +17,48 @@ import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.core.Releasables;
 
 /**
- * Aggregator state for an array of doubles.
+ * Aggregator state for an array of doubles. It is created in a mode where it
+ * won't track the {@code groupId}s that are sent to it and it is the
+ * responsibility of the caller to only fetch values for {@code groupId}s
+ * that it has sent using the {@code selected} parameter when building the
+ * results. This is fine when there are no {@code null} values in the input
+ * data. But once there are null values in the input data it is
+ * <strong>much</strong> more convenient to only send non-null values and
+ * the tracking built into the grouping code can't track that. In that case
+ * call {@link #enableGroupIdTracking} to transition the state into a mode
+ * where it'll track which {@code groupIds} have been written.
+ * <p>
  * This class is generated. Do not edit it.
+ * </p>
  */
-final class DoubleArrayState implements GroupingAggregatorState {
-    private final BigArrays bigArrays;
+final class DoubleArrayState extends AbstractArrayState implements GroupingAggregatorState {
     private final double init;
 
     private DoubleArray values;
-    /**
-     * Total number of groups {@code <=} values.length.
-     */
-    private int largestIndex;
-    private BitArray nonNulls;
 
     DoubleArrayState(BigArrays bigArrays, double init) {
-        this.bigArrays = bigArrays;
+        super(bigArrays);
         this.values = bigArrays.newDoubleArray(1, false);
         this.values.set(0, init);
         this.init = init;
     }
 
-    double get(int index) {
-        return values.get(index);
+    double get(int groupId) {
+        return values.get(groupId);
     }
 
-    double getOrDefault(int index) {
-        return index <= largestIndex ? values.get(index) : init;
+    double getOrDefault(int groupId) {
+        return groupId < values.size() ? values.get(groupId) : init;
     }
 
-    void set(double value, int index) {
-        if (index > largestIndex) {
-            ensureCapacity(index);
-            largestIndex = index;
-        }
-        values.set(index, value);
-        if (nonNulls != null) {
-            nonNulls.set(index);
-        }
-    }
-
-    void putNull(int index) {
-        if (index > largestIndex) {
-            ensureCapacity(index);
-            largestIndex = index;
-        }
-        if (nonNulls == null) {
-            nonNulls = new BitArray(index + 1, bigArrays);
-            for (int i = 0; i < index; i++) {
-                nonNulls.set(i);
-            }
-        } else {
-            // Do nothing. Null is represented by the default value of false for get(int),
-            // and any present value trumps a null value in our aggregations.
-        }
-    }
-
-    boolean hasValue(int index) {
-        return nonNulls == null || nonNulls.get(index);
+    void set(int groupId, double value) {
+        ensureCapacity(groupId);
+        values.set(groupId, value);
+        trackGroupId(groupId);
     }
 
     Block toValuesBlock(org.elasticsearch.compute.data.IntVector selected) {
-        if (nonNulls == null) {
+        if (false == trackingGroupIds()) {
             DoubleVector.Builder builder = DoubleVector.newVectorBuilder(selected.getPositionCount());
             for (int i = 0; i < selected.getPositionCount(); i++) {
                 builder.appendDouble(values.get(selected.getInt(i)));
@@ -98,10 +77,10 @@ final class DoubleArrayState implements GroupingAggregatorState {
         return builder.build();
     }
 
-    private void ensureCapacity(int position) {
-        if (position >= values.size()) {
+    private void ensureCapacity(int groupId) {
+        if (groupId >= values.size()) {
             long prevSize = values.size();
-            values = bigArrays.grow(values, position + 1);
+            values = bigArrays.grow(values, groupId + 1);
             values.fill(prevSize, values.size(), init);
         }
     }
@@ -111,18 +90,22 @@ final class DoubleArrayState implements GroupingAggregatorState {
     public void toIntermediate(Block[] blocks, int offset, IntVector selected) {
         assert blocks.length >= offset + 2;
         var valuesBuilder = DoubleBlock.newBlockBuilder(selected.getPositionCount());
-        var nullsBuilder = BooleanBlock.newBlockBuilder(selected.getPositionCount());
+        var hasValueBuilder = BooleanBlock.newBlockBuilder(selected.getPositionCount());
         for (int i = 0; i < selected.getPositionCount(); i++) {
             int group = selected.getInt(i);
-            valuesBuilder.appendDouble(values.get(group));
-            nullsBuilder.appendBoolean(hasValue(group));
+            if (group < values.size()) {
+                valuesBuilder.appendDouble(values.get(group));
+            } else {
+                valuesBuilder.appendDouble(0); // TODO can we just use null?
+            }
+            hasValueBuilder.appendBoolean(hasValue(group));
         }
         blocks[offset + 0] = valuesBuilder.build();
-        blocks[offset + 1] = nullsBuilder.build();
+        blocks[offset + 1] = hasValueBuilder.build();
     }
 
     @Override
     public void close() {
-        Releasables.close(values, nonNulls);
+        Releasables.close(values, super::close);
     }
 }

+ 35 - 52
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/IntArrayState.java

@@ -8,7 +8,6 @@
 package org.elasticsearch.compute.aggregation;
 
 import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.IntArray;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BooleanBlock;
@@ -17,68 +16,48 @@ import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.core.Releasables;
 
 /**
- * Aggregator state for an array of ints.
+ * Aggregator state for an array of ints. It is created in a mode where it
+ * won't track the {@code groupId}s that are sent to it and it is the
+ * responsibility of the caller to only fetch values for {@code groupId}s
+ * that it has sent using the {@code selected} parameter when building the
+ * results. This is fine when there are no {@code null} values in the input
+ * data. But once there are null values in the input data it is
+ * <strong>much</strong> more convenient to only send non-null values and
+ * the tracking built into the grouping code can't track that. In that case
+ * call {@link #enableGroupIdTracking} to transition the state into a mode
+ * where it'll track which {@code groupIds} have been written.
+ * <p>
  * This class is generated. Do not edit it.
+ * </p>
  */
-final class IntArrayState implements GroupingAggregatorState {
-    private final BigArrays bigArrays;
+final class IntArrayState extends AbstractArrayState implements GroupingAggregatorState {
     private final int init;
 
     private IntArray values;
-    /**
-     * Total number of groups {@code <=} values.length.
-     */
-    private int largestIndex;
-    private BitArray nonNulls;
 
     IntArrayState(BigArrays bigArrays, int init) {
-        this.bigArrays = bigArrays;
+        super(bigArrays);
         this.values = bigArrays.newIntArray(1, false);
         this.values.set(0, init);
         this.init = init;
     }
 
-    int get(int index) {
-        return values.get(index);
+    int get(int groupId) {
+        return values.get(groupId);
     }
 
-    int getOrDefault(int index) {
-        return index <= largestIndex ? values.get(index) : init;
+    int getOrDefault(int groupId) {
+        return groupId < values.size() ? values.get(groupId) : init;
     }
 
-    void set(int value, int index) {
-        if (index > largestIndex) {
-            ensureCapacity(index);
-            largestIndex = index;
-        }
-        values.set(index, value);
-        if (nonNulls != null) {
-            nonNulls.set(index);
-        }
-    }
-
-    void putNull(int index) {
-        if (index > largestIndex) {
-            ensureCapacity(index);
-            largestIndex = index;
-        }
-        if (nonNulls == null) {
-            nonNulls = new BitArray(index + 1, bigArrays);
-            for (int i = 0; i < index; i++) {
-                nonNulls.set(i);
-            }
-        } else {
-            // Do nothing. Null is represented by the default value of false for get(int),
-            // and any present value trumps a null value in our aggregations.
-        }
-    }
-
-    boolean hasValue(int index) {
-        return nonNulls == null || nonNulls.get(index);
+    void set(int groupId, int value) {
+        ensureCapacity(groupId);
+        values.set(groupId, value);
+        trackGroupId(groupId);
     }
 
     Block toValuesBlock(org.elasticsearch.compute.data.IntVector selected) {
-        if (nonNulls == null) {
+        if (false == trackingGroupIds()) {
             IntVector.Builder builder = IntVector.newVectorBuilder(selected.getPositionCount());
             for (int i = 0; i < selected.getPositionCount(); i++) {
                 builder.appendInt(values.get(selected.getInt(i)));
@@ -97,10 +76,10 @@ final class IntArrayState implements GroupingAggregatorState {
         return builder.build();
     }
 
-    private void ensureCapacity(int position) {
-        if (position >= values.size()) {
+    private void ensureCapacity(int groupId) {
+        if (groupId >= values.size()) {
             long prevSize = values.size();
-            values = bigArrays.grow(values, position + 1);
+            values = bigArrays.grow(values, groupId + 1);
             values.fill(prevSize, values.size(), init);
         }
     }
@@ -110,18 +89,22 @@ final class IntArrayState implements GroupingAggregatorState {
     public void toIntermediate(Block[] blocks, int offset, IntVector selected) {
         assert blocks.length >= offset + 2;
         var valuesBuilder = IntBlock.newBlockBuilder(selected.getPositionCount());
-        var nullsBuilder = BooleanBlock.newBlockBuilder(selected.getPositionCount());
+        var hasValueBuilder = BooleanBlock.newBlockBuilder(selected.getPositionCount());
         for (int i = 0; i < selected.getPositionCount(); i++) {
             int group = selected.getInt(i);
-            valuesBuilder.appendInt(values.get(group));
-            nullsBuilder.appendBoolean(hasValue(group));
+            if (group < values.size()) {
+                valuesBuilder.appendInt(values.get(group));
+            } else {
+                valuesBuilder.appendInt(0); // TODO can we just use null?
+            }
+            hasValueBuilder.appendBoolean(hasValue(group));
         }
         blocks[offset + 0] = valuesBuilder.build();
-        blocks[offset + 1] = nullsBuilder.build();
+        blocks[offset + 1] = hasValueBuilder.build();
     }
 
     @Override
     public void close() {
-        Releasables.close(values, nonNulls);
+        Releasables.close(values, super::close);
     }
 }

+ 39 - 61
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/LongArrayState.java

@@ -8,7 +8,6 @@
 package org.elasticsearch.compute.aggregation;
 
 import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.LongArray;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BooleanBlock;
@@ -18,79 +17,54 @@ import org.elasticsearch.compute.data.LongVector;
 import org.elasticsearch.core.Releasables;
 
 /**
- * Aggregator state for an array of longs.
+ * Aggregator state for an array of longs. It is created in a mode where it
+ * won't track the {@code groupId}s that are sent to it and it is the
+ * responsibility of the caller to only fetch values for {@code groupId}s
+ * that it has sent using the {@code selected} parameter when building the
+ * results. This is fine when there are no {@code null} values in the input
+ * data. But once there are null values in the input data it is
+ * <strong>much</strong> more convenient to only send non-null values and
+ * the tracking built into the grouping code can't track that. In that case
+ * call {@link #enableGroupIdTracking} to transition the state into a mode
+ * where it'll track which {@code groupId}s have been written.
+ * <p>
  * This class is generated. Do not edit it.
+ * </p>
  */
-final class LongArrayState implements GroupingAggregatorState {
-    private final BigArrays bigArrays;
+final class LongArrayState extends AbstractArrayState implements GroupingAggregatorState {
     private final long init;
 
     private LongArray values;
-    /**
-     * Total number of groups {@code <=} values.length.
-     */
-    private int largestIndex;
-    private BitArray nonNulls;
 
     LongArrayState(BigArrays bigArrays, long init) {
-        this.bigArrays = bigArrays;
+        super(bigArrays);
         this.values = bigArrays.newLongArray(1, false);
         this.values.set(0, init);
         this.init = init;
     }
 
-    long get(int index) {
-        return values.get(index);
+    long get(int groupId) {
+        return values.get(groupId);
     }
 
-    long getOrDefault(int index) {
-        return index <= largestIndex ? values.get(index) : init;
+    long getOrDefault(int groupId) {
+        return groupId < values.size() ? values.get(groupId) : init;
     }
 
-    void set(long value, int index) {
-        if (index > largestIndex) {
-            ensureCapacity(index);
-            largestIndex = index;
-        }
-        values.set(index, value);
-        if (nonNulls != null) {
-            nonNulls.set(index);
-        }
+    void set(int groupId, long value) {
+        ensureCapacity(groupId);
+        values.set(groupId, value);
+        trackGroupId(groupId);
     }
 
-    void increment(long value, int index) {
-        if (index > largestIndex) {
-            ensureCapacity(index);
-            largestIndex = index;
-        }
-        values.increment(index, value);
-        if (nonNulls != null) {
-            nonNulls.set(index);
-        }
-    }
-
-    void putNull(int index) {
-        if (index > largestIndex) {
-            ensureCapacity(index);
-            largestIndex = index;
-        }
-        if (nonNulls == null) {
-            nonNulls = new BitArray(index + 1, bigArrays);
-            for (int i = 0; i < index; i++) {
-                nonNulls.set(i);
-            }
-        } else {
-            // Do nothing. Null is represented by the default value of false for get(int),
-            // and any present value trumps a null value in our aggregations.
-        }
-    }
-
-    boolean hasValue(int index) {
-        return nonNulls == null || nonNulls.get(index);
+    void increment(int groupId, long value) {
+        ensureCapacity(groupId);
+        values.increment(groupId, value);
+        trackGroupId(groupId);
     }
 
     Block toValuesBlock(org.elasticsearch.compute.data.IntVector selected) {
-        if (nonNulls == null) {
+        if (false == trackingGroupIds()) {
             LongVector.Builder builder = LongVector.newVectorBuilder(selected.getPositionCount());
             for (int i = 0; i < selected.getPositionCount(); i++) {
                 builder.appendLong(values.get(selected.getInt(i)));
@@ -109,10 +83,10 @@ final class LongArrayState implements GroupingAggregatorState {
         return builder.build();
     }
 
-    private void ensureCapacity(int position) {
-        if (position >= values.size()) {
+    private void ensureCapacity(int groupId) {
+        if (groupId >= values.size()) {
             long prevSize = values.size();
-            values = bigArrays.grow(values, position + 1);
+            values = bigArrays.grow(values, groupId + 1);
             values.fill(prevSize, values.size(), init);
         }
     }
@@ -122,18 +96,22 @@ final class LongArrayState implements GroupingAggregatorState {
     public void toIntermediate(Block[] blocks, int offset, IntVector selected) {
         assert blocks.length >= offset + 2;
         var valuesBuilder = LongBlock.newBlockBuilder(selected.getPositionCount());
-        var nullsBuilder = BooleanBlock.newBlockBuilder(selected.getPositionCount());
+        var hasValueBuilder = BooleanBlock.newBlockBuilder(selected.getPositionCount());
         for (int i = 0; i < selected.getPositionCount(); i++) {
             int group = selected.getInt(i);
-            valuesBuilder.appendLong(values.get(group));
-            nullsBuilder.appendBoolean(hasValue(group));
+            if (group < values.size()) {
+                valuesBuilder.appendLong(values.get(group));
+            } else {
+                valuesBuilder.appendLong(0); // TODO can we just use null?
+            }
+            hasValueBuilder.appendBoolean(hasValue(group));
         }
         blocks[offset + 0] = valuesBuilder.build();
-        blocks[offset + 1] = nullsBuilder.build();
+        blocks[offset + 1] = hasValueBuilder.build();
     }
 
     @Override
     public void close() {
-        Releasables.close(values, nonNulls);
+        Releasables.close(values, super::close);
     }
 }

+ 8 - 28
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBooleanGroupingAggregatorFunction.java

@@ -56,24 +56,27 @@ public final class CountDistinctBooleanGroupingAggregatorFunction implements Gro
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     BooleanBlock valuesBlock = (BooleanBlock) uncastValuesBlock;
     BooleanVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -103,7 +106,6 @@ public final class CountDistinctBooleanGroupingAggregatorFunction implements Gro
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -121,14 +123,6 @@ public final class CountDistinctBooleanGroupingAggregatorFunction implements Gro
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
-    }
-  }
-
   private void addRawInput(int positionOffset, LongBlock groups, BooleanBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -139,7 +133,6 @@ public final class CountDistinctBooleanGroupingAggregatorFunction implements Gro
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -165,23 +158,9 @@ public final class CountDistinctBooleanGroupingAggregatorFunction implements Gro
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
-      }
-    }
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     BooleanVector fbit = page.<BooleanBlock>getBlock(channels.get(0)).asVector();
     BooleanVector tbit = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
@@ -198,6 +177,7 @@ public final class CountDistinctBooleanGroupingAggregatorFunction implements Gro
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     CountDistinctBooleanAggregator.GroupingState inState = ((CountDistinctBooleanGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     CountDistinctBooleanAggregator.combineStates(state, groupId, inState, position);
   }
 

+ 8 - 30
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctBytesRefGroupingAggregatorFunction.java

@@ -59,24 +59,27 @@ public final class CountDistinctBytesRefGroupingAggregatorFunction implements Gr
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     BytesRefBlock valuesBlock = (BytesRefBlock) uncastValuesBlock;
     BytesRefVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -107,7 +110,6 @@ public final class CountDistinctBytesRefGroupingAggregatorFunction implements Gr
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -126,15 +128,6 @@ public final class CountDistinctBytesRefGroupingAggregatorFunction implements Gr
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    BytesRef scratch = new BytesRef();
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
-    }
-  }
-
   private void addRawInput(int positionOffset, LongBlock groups, BytesRefBlock values) {
     BytesRef scratch = new BytesRef();
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
@@ -146,7 +139,6 @@ public final class CountDistinctBytesRefGroupingAggregatorFunction implements Gr
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -173,24 +165,9 @@ public final class CountDistinctBytesRefGroupingAggregatorFunction implements Gr
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    BytesRef scratch = new BytesRef();
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
-      }
-    }
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     BytesRefVector hll = page.<BytesRefBlock>getBlock(channels.get(0)).asVector();
     BytesRef scratch = new BytesRef();
@@ -206,6 +183,7 @@ public final class CountDistinctBytesRefGroupingAggregatorFunction implements Gr
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     HllStates.GroupingState inState = ((CountDistinctBytesRefGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     CountDistinctBytesRefAggregator.combineStates(state, groupId, inState, position);
   }
 

+ 8 - 28
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctDoubleGroupingAggregatorFunction.java

@@ -61,24 +61,27 @@ public final class CountDistinctDoubleGroupingAggregatorFunction implements Grou
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     DoubleBlock valuesBlock = (DoubleBlock) uncastValuesBlock;
     DoubleVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -108,7 +111,6 @@ public final class CountDistinctDoubleGroupingAggregatorFunction implements Grou
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -126,14 +128,6 @@ public final class CountDistinctDoubleGroupingAggregatorFunction implements Grou
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
-    }
-  }
-
   private void addRawInput(int positionOffset, LongBlock groups, DoubleBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -144,7 +138,6 @@ public final class CountDistinctDoubleGroupingAggregatorFunction implements Grou
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -170,23 +163,9 @@ public final class CountDistinctDoubleGroupingAggregatorFunction implements Grou
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
-      }
-    }
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     BytesRefVector hll = page.<BytesRefBlock>getBlock(channels.get(0)).asVector();
     BytesRef scratch = new BytesRef();
@@ -202,6 +181,7 @@ public final class CountDistinctDoubleGroupingAggregatorFunction implements Grou
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     HllStates.GroupingState inState = ((CountDistinctDoubleGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     CountDistinctDoubleAggregator.combineStates(state, groupId, inState, position);
   }
 

+ 8 - 28
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctIntGroupingAggregatorFunction.java

@@ -60,24 +60,27 @@ public final class CountDistinctIntGroupingAggregatorFunction implements Groupin
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     IntBlock valuesBlock = (IntBlock) uncastValuesBlock;
     IntVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -107,7 +110,6 @@ public final class CountDistinctIntGroupingAggregatorFunction implements Groupin
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -125,14 +127,6 @@ public final class CountDistinctIntGroupingAggregatorFunction implements Groupin
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
-    }
-  }
-
   private void addRawInput(int positionOffset, LongBlock groups, IntBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -143,7 +137,6 @@ public final class CountDistinctIntGroupingAggregatorFunction implements Groupin
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -169,23 +162,9 @@ public final class CountDistinctIntGroupingAggregatorFunction implements Groupin
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
-      }
-    }
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     BytesRefVector hll = page.<BytesRefBlock>getBlock(channels.get(0)).asVector();
     BytesRef scratch = new BytesRef();
@@ -201,6 +180,7 @@ public final class CountDistinctIntGroupingAggregatorFunction implements Groupin
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     HllStates.GroupingState inState = ((CountDistinctIntGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     CountDistinctIntAggregator.combineStates(state, groupId, inState, position);
   }
 

+ 8 - 28
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CountDistinctLongGroupingAggregatorFunction.java

@@ -59,24 +59,27 @@ public final class CountDistinctLongGroupingAggregatorFunction implements Groupi
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     LongBlock valuesBlock = (LongBlock) uncastValuesBlock;
     LongVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -106,7 +109,6 @@ public final class CountDistinctLongGroupingAggregatorFunction implements Groupi
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -124,14 +126,6 @@ public final class CountDistinctLongGroupingAggregatorFunction implements Groupi
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
-    }
-  }
-
   private void addRawInput(int positionOffset, LongBlock groups, LongBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -142,7 +136,6 @@ public final class CountDistinctLongGroupingAggregatorFunction implements Groupi
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -168,23 +161,9 @@ public final class CountDistinctLongGroupingAggregatorFunction implements Groupi
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
-      }
-    }
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     BytesRefVector hll = page.<BytesRefBlock>getBlock(channels.get(0)).asVector();
     BytesRef scratch = new BytesRef();
@@ -200,6 +179,7 @@ public final class CountDistinctLongGroupingAggregatorFunction implements Groupi
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     HllStates.GroupingState inState = ((CountDistinctLongGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     CountDistinctLongAggregator.combineStates(state, groupId, inState, position);
   }
 

+ 14 - 38
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxDoubleGroupingAggregatorFunction.java

@@ -54,24 +54,27 @@ public final class MaxDoubleGroupingAggregatorFunction implements GroupingAggreg
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     DoubleBlock valuesBlock = (DoubleBlock) uncastValuesBlock;
     DoubleVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -101,13 +104,12 @@ public final class MaxDoubleGroupingAggregatorFunction implements GroupingAggreg
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
       int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
       for (int v = valuesStart; v < valuesEnd; v++) {
-        state.set(MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v)), groupId);
+        state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v)));
       }
     }
   }
@@ -115,15 +117,7 @@ public final class MaxDoubleGroupingAggregatorFunction implements GroupingAggreg
   private void addRawInput(int positionOffset, LongVector groups, DoubleVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      state.set(MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset)), groupId);
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
+      state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset)));
     }
   }
 
@@ -137,13 +131,12 @@ public final class MaxDoubleGroupingAggregatorFunction implements GroupingAggreg
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
         int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
         for (int v = valuesStart; v < valuesEnd; v++) {
-          state.set(MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v)), groupId);
+          state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v)));
         }
       }
     }
@@ -158,28 +151,14 @@ public final class MaxDoubleGroupingAggregatorFunction implements GroupingAggreg
       int groupEnd = groupStart + groups.getValueCount(groupPosition);
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
-        state.set(MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset)), groupId);
-      }
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
+        state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset)));
       }
     }
   }
 
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     DoubleVector max = page.<DoubleBlock>getBlock(channels.get(0)).asVector();
     BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
@@ -187,9 +166,7 @@ public final class MaxDoubleGroupingAggregatorFunction implements GroupingAggreg
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (seen.getBoolean(groupPosition + positionOffset)) {
-        state.set(MaxDoubleAggregator.combine(state.getOrDefault(groupId), max.getDouble(groupPosition + positionOffset)), groupId);
-      } else {
-        state.putNull(groupId);
+        state.set(groupId, MaxDoubleAggregator.combine(max.getDouble(groupPosition + positionOffset), state.getOrDefault(groupId)));
       }
     }
   }
@@ -200,10 +177,9 @@ public final class MaxDoubleGroupingAggregatorFunction implements GroupingAggreg
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     DoubleArrayState inState = ((MaxDoubleGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     if (inState.hasValue(position)) {
-      state.set(MaxDoubleAggregator.combine(state.getOrDefault(groupId), inState.get(position)), groupId);
-    } else {
-      state.putNull(groupId);
+      state.set(groupId, MaxDoubleAggregator.combine(state.getOrDefault(groupId), inState.get(position)));
     }
   }
 

+ 14 - 38
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxIntGroupingAggregatorFunction.java

@@ -53,24 +53,27 @@ public final class MaxIntGroupingAggregatorFunction implements GroupingAggregato
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     IntBlock valuesBlock = (IntBlock) uncastValuesBlock;
     IntVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -100,13 +103,12 @@ public final class MaxIntGroupingAggregatorFunction implements GroupingAggregato
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
       int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
       for (int v = valuesStart; v < valuesEnd; v++) {
-        state.set(MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v)), groupId);
+        state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v)));
       }
     }
   }
@@ -114,15 +116,7 @@ public final class MaxIntGroupingAggregatorFunction implements GroupingAggregato
   private void addRawInput(int positionOffset, LongVector groups, IntVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      state.set(MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset)), groupId);
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
+      state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset)));
     }
   }
 
@@ -136,13 +130,12 @@ public final class MaxIntGroupingAggregatorFunction implements GroupingAggregato
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
         int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
         for (int v = valuesStart; v < valuesEnd; v++) {
-          state.set(MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v)), groupId);
+          state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v)));
         }
       }
     }
@@ -157,28 +150,14 @@ public final class MaxIntGroupingAggregatorFunction implements GroupingAggregato
       int groupEnd = groupStart + groups.getValueCount(groupPosition);
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
-        state.set(MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset)), groupId);
-      }
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
+        state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset)));
       }
     }
   }
 
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     IntVector max = page.<IntBlock>getBlock(channels.get(0)).asVector();
     BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
@@ -186,9 +165,7 @@ public final class MaxIntGroupingAggregatorFunction implements GroupingAggregato
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (seen.getBoolean(groupPosition + positionOffset)) {
-        state.set(MaxIntAggregator.combine(state.getOrDefault(groupId), max.getInt(groupPosition + positionOffset)), groupId);
-      } else {
-        state.putNull(groupId);
+        state.set(groupId, MaxIntAggregator.combine(max.getInt(groupPosition + positionOffset), state.getOrDefault(groupId)));
       }
     }
   }
@@ -199,10 +176,9 @@ public final class MaxIntGroupingAggregatorFunction implements GroupingAggregato
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     IntArrayState inState = ((MaxIntGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     if (inState.hasValue(position)) {
-      state.set(MaxIntAggregator.combine(state.getOrDefault(groupId), inState.get(position)), groupId);
-    } else {
-      state.putNull(groupId);
+      state.set(groupId, MaxIntAggregator.combine(state.getOrDefault(groupId), inState.get(position)));
     }
   }
 

+ 14 - 38
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxLongGroupingAggregatorFunction.java

@@ -52,24 +52,27 @@ public final class MaxLongGroupingAggregatorFunction implements GroupingAggregat
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     LongBlock valuesBlock = (LongBlock) uncastValuesBlock;
     LongVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -99,13 +102,12 @@ public final class MaxLongGroupingAggregatorFunction implements GroupingAggregat
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
       int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
       for (int v = valuesStart; v < valuesEnd; v++) {
-        state.set(MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)), groupId);
+        state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)));
       }
     }
   }
@@ -113,15 +115,7 @@ public final class MaxLongGroupingAggregatorFunction implements GroupingAggregat
   private void addRawInput(int positionOffset, LongVector groups, LongVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      state.set(MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)), groupId);
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
+      state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)));
     }
   }
 
@@ -135,13 +129,12 @@ public final class MaxLongGroupingAggregatorFunction implements GroupingAggregat
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
         int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
         for (int v = valuesStart; v < valuesEnd; v++) {
-          state.set(MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)), groupId);
+          state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)));
         }
       }
     }
@@ -156,28 +149,14 @@ public final class MaxLongGroupingAggregatorFunction implements GroupingAggregat
       int groupEnd = groupStart + groups.getValueCount(groupPosition);
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
-        state.set(MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)), groupId);
-      }
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
+        state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)));
       }
     }
   }
 
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     LongVector max = page.<LongBlock>getBlock(channels.get(0)).asVector();
     BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
@@ -185,9 +164,7 @@ public final class MaxLongGroupingAggregatorFunction implements GroupingAggregat
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (seen.getBoolean(groupPosition + positionOffset)) {
-        state.set(MaxLongAggregator.combine(state.getOrDefault(groupId), max.getLong(groupPosition + positionOffset)), groupId);
-      } else {
-        state.putNull(groupId);
+        state.set(groupId, MaxLongAggregator.combine(max.getLong(groupPosition + positionOffset), state.getOrDefault(groupId)));
       }
     }
   }
@@ -198,10 +175,9 @@ public final class MaxLongGroupingAggregatorFunction implements GroupingAggregat
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     LongArrayState inState = ((MaxLongGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     if (inState.hasValue(position)) {
-      state.set(MaxLongAggregator.combine(state.getOrDefault(groupId), inState.get(position)), groupId);
-    } else {
-      state.putNull(groupId);
+      state.set(groupId, MaxLongAggregator.combine(state.getOrDefault(groupId), inState.get(position)));
     }
   }
 

+ 8 - 28
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationDoubleGroupingAggregatorFunction.java

@@ -58,24 +58,27 @@ public final class MedianAbsoluteDeviationDoubleGroupingAggregatorFunction imple
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     DoubleBlock valuesBlock = (DoubleBlock) uncastValuesBlock;
     DoubleVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -105,7 +108,6 @@ public final class MedianAbsoluteDeviationDoubleGroupingAggregatorFunction imple
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -123,14 +125,6 @@ public final class MedianAbsoluteDeviationDoubleGroupingAggregatorFunction imple
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
-    }
-  }
-
   private void addRawInput(int positionOffset, LongBlock groups, DoubleBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -141,7 +135,6 @@ public final class MedianAbsoluteDeviationDoubleGroupingAggregatorFunction imple
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -167,23 +160,9 @@ public final class MedianAbsoluteDeviationDoubleGroupingAggregatorFunction imple
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
-      }
-    }
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     BytesRefVector quart = page.<BytesRefBlock>getBlock(channels.get(0)).asVector();
     BytesRef scratch = new BytesRef();
@@ -199,6 +178,7 @@ public final class MedianAbsoluteDeviationDoubleGroupingAggregatorFunction imple
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     QuantileStates.GroupingState inState = ((MedianAbsoluteDeviationDoubleGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     MedianAbsoluteDeviationDoubleAggregator.combineStates(state, groupId, inState, position);
   }
 

+ 8 - 28
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationIntGroupingAggregatorFunction.java

@@ -57,24 +57,27 @@ public final class MedianAbsoluteDeviationIntGroupingAggregatorFunction implemen
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     IntBlock valuesBlock = (IntBlock) uncastValuesBlock;
     IntVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -104,7 +107,6 @@ public final class MedianAbsoluteDeviationIntGroupingAggregatorFunction implemen
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -122,14 +124,6 @@ public final class MedianAbsoluteDeviationIntGroupingAggregatorFunction implemen
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
-    }
-  }
-
   private void addRawInput(int positionOffset, LongBlock groups, IntBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -140,7 +134,6 @@ public final class MedianAbsoluteDeviationIntGroupingAggregatorFunction implemen
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -166,23 +159,9 @@ public final class MedianAbsoluteDeviationIntGroupingAggregatorFunction implemen
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
-      }
-    }
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     BytesRefVector quart = page.<BytesRefBlock>getBlock(channels.get(0)).asVector();
     BytesRef scratch = new BytesRef();
@@ -198,6 +177,7 @@ public final class MedianAbsoluteDeviationIntGroupingAggregatorFunction implemen
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     QuantileStates.GroupingState inState = ((MedianAbsoluteDeviationIntGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     MedianAbsoluteDeviationIntAggregator.combineStates(state, groupId, inState, position);
   }
 

+ 8 - 28
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MedianAbsoluteDeviationLongGroupingAggregatorFunction.java

@@ -56,24 +56,27 @@ public final class MedianAbsoluteDeviationLongGroupingAggregatorFunction impleme
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     LongBlock valuesBlock = (LongBlock) uncastValuesBlock;
     LongVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -103,7 +106,6 @@ public final class MedianAbsoluteDeviationLongGroupingAggregatorFunction impleme
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -121,14 +123,6 @@ public final class MedianAbsoluteDeviationLongGroupingAggregatorFunction impleme
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
-    }
-  }
-
   private void addRawInput(int positionOffset, LongBlock groups, LongBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -139,7 +133,6 @@ public final class MedianAbsoluteDeviationLongGroupingAggregatorFunction impleme
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -165,23 +158,9 @@ public final class MedianAbsoluteDeviationLongGroupingAggregatorFunction impleme
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
-      }
-    }
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     BytesRefVector quart = page.<BytesRefBlock>getBlock(channels.get(0)).asVector();
     BytesRef scratch = new BytesRef();
@@ -197,6 +176,7 @@ public final class MedianAbsoluteDeviationLongGroupingAggregatorFunction impleme
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     QuantileStates.GroupingState inState = ((MedianAbsoluteDeviationLongGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     MedianAbsoluteDeviationLongAggregator.combineStates(state, groupId, inState, position);
   }
 

+ 14 - 38
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinDoubleGroupingAggregatorFunction.java

@@ -54,24 +54,27 @@ public final class MinDoubleGroupingAggregatorFunction implements GroupingAggreg
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     DoubleBlock valuesBlock = (DoubleBlock) uncastValuesBlock;
     DoubleVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -101,13 +104,12 @@ public final class MinDoubleGroupingAggregatorFunction implements GroupingAggreg
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
       int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
       for (int v = valuesStart; v < valuesEnd; v++) {
-        state.set(MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v)), groupId);
+        state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v)));
       }
     }
   }
@@ -115,15 +117,7 @@ public final class MinDoubleGroupingAggregatorFunction implements GroupingAggreg
   private void addRawInput(int positionOffset, LongVector groups, DoubleVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      state.set(MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset)), groupId);
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
+      state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset)));
     }
   }
 
@@ -137,13 +131,12 @@ public final class MinDoubleGroupingAggregatorFunction implements GroupingAggreg
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
         int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
         for (int v = valuesStart; v < valuesEnd; v++) {
-          state.set(MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v)), groupId);
+          state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(v)));
         }
       }
     }
@@ -158,28 +151,14 @@ public final class MinDoubleGroupingAggregatorFunction implements GroupingAggreg
       int groupEnd = groupStart + groups.getValueCount(groupPosition);
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
-        state.set(MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset)), groupId);
-      }
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
+        state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), values.getDouble(groupPosition + positionOffset)));
       }
     }
   }
 
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     DoubleVector min = page.<DoubleBlock>getBlock(channels.get(0)).asVector();
     BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
@@ -187,9 +166,7 @@ public final class MinDoubleGroupingAggregatorFunction implements GroupingAggreg
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (seen.getBoolean(groupPosition + positionOffset)) {
-        state.set(MinDoubleAggregator.combine(state.getOrDefault(groupId), min.getDouble(groupPosition + positionOffset)), groupId);
-      } else {
-        state.putNull(groupId);
+        state.set(groupId, MinDoubleAggregator.combine(min.getDouble(groupPosition + positionOffset), state.getOrDefault(groupId)));
       }
     }
   }
@@ -200,10 +177,9 @@ public final class MinDoubleGroupingAggregatorFunction implements GroupingAggreg
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     DoubleArrayState inState = ((MinDoubleGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     if (inState.hasValue(position)) {
-      state.set(MinDoubleAggregator.combine(state.getOrDefault(groupId), inState.get(position)), groupId);
-    } else {
-      state.putNull(groupId);
+      state.set(groupId, MinDoubleAggregator.combine(state.getOrDefault(groupId), inState.get(position)));
     }
   }
 

+ 14 - 38
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinIntGroupingAggregatorFunction.java

@@ -53,24 +53,27 @@ public final class MinIntGroupingAggregatorFunction implements GroupingAggregato
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     IntBlock valuesBlock = (IntBlock) uncastValuesBlock;
     IntVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -100,13 +103,12 @@ public final class MinIntGroupingAggregatorFunction implements GroupingAggregato
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
       int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
       for (int v = valuesStart; v < valuesEnd; v++) {
-        state.set(MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v)), groupId);
+        state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v)));
       }
     }
   }
@@ -114,15 +116,7 @@ public final class MinIntGroupingAggregatorFunction implements GroupingAggregato
   private void addRawInput(int positionOffset, LongVector groups, IntVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      state.set(MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset)), groupId);
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
+      state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset)));
     }
   }
 
@@ -136,13 +130,12 @@ public final class MinIntGroupingAggregatorFunction implements GroupingAggregato
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
         int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
         for (int v = valuesStart; v < valuesEnd; v++) {
-          state.set(MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v)), groupId);
+          state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v)));
         }
       }
     }
@@ -157,28 +150,14 @@ public final class MinIntGroupingAggregatorFunction implements GroupingAggregato
       int groupEnd = groupStart + groups.getValueCount(groupPosition);
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
-        state.set(MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset)), groupId);
-      }
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
+        state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset)));
       }
     }
   }
 
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     IntVector min = page.<IntBlock>getBlock(channels.get(0)).asVector();
     BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
@@ -186,9 +165,7 @@ public final class MinIntGroupingAggregatorFunction implements GroupingAggregato
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (seen.getBoolean(groupPosition + positionOffset)) {
-        state.set(MinIntAggregator.combine(state.getOrDefault(groupId), min.getInt(groupPosition + positionOffset)), groupId);
-      } else {
-        state.putNull(groupId);
+        state.set(groupId, MinIntAggregator.combine(min.getInt(groupPosition + positionOffset), state.getOrDefault(groupId)));
       }
     }
   }
@@ -199,10 +176,9 @@ public final class MinIntGroupingAggregatorFunction implements GroupingAggregato
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     IntArrayState inState = ((MinIntGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     if (inState.hasValue(position)) {
-      state.set(MinIntAggregator.combine(state.getOrDefault(groupId), inState.get(position)), groupId);
-    } else {
-      state.putNull(groupId);
+      state.set(groupId, MinIntAggregator.combine(state.getOrDefault(groupId), inState.get(position)));
     }
   }
 

+ 14 - 38
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinLongGroupingAggregatorFunction.java

@@ -52,24 +52,27 @@ public final class MinLongGroupingAggregatorFunction implements GroupingAggregat
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     LongBlock valuesBlock = (LongBlock) uncastValuesBlock;
     LongVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -99,13 +102,12 @@ public final class MinLongGroupingAggregatorFunction implements GroupingAggregat
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
       int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
       for (int v = valuesStart; v < valuesEnd; v++) {
-        state.set(MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)), groupId);
+        state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)));
       }
     }
   }
@@ -113,15 +115,7 @@ public final class MinLongGroupingAggregatorFunction implements GroupingAggregat
   private void addRawInput(int positionOffset, LongVector groups, LongVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      state.set(MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)), groupId);
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
+      state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)));
     }
   }
 
@@ -135,13 +129,12 @@ public final class MinLongGroupingAggregatorFunction implements GroupingAggregat
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
         int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
         for (int v = valuesStart; v < valuesEnd; v++) {
-          state.set(MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)), groupId);
+          state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)));
         }
       }
     }
@@ -156,28 +149,14 @@ public final class MinLongGroupingAggregatorFunction implements GroupingAggregat
       int groupEnd = groupStart + groups.getValueCount(groupPosition);
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
-        state.set(MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)), groupId);
-      }
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
+        state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)));
       }
     }
   }
 
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     LongVector min = page.<LongBlock>getBlock(channels.get(0)).asVector();
     BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
@@ -185,9 +164,7 @@ public final class MinLongGroupingAggregatorFunction implements GroupingAggregat
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (seen.getBoolean(groupPosition + positionOffset)) {
-        state.set(MinLongAggregator.combine(state.getOrDefault(groupId), min.getLong(groupPosition + positionOffset)), groupId);
-      } else {
-        state.putNull(groupId);
+        state.set(groupId, MinLongAggregator.combine(min.getLong(groupPosition + positionOffset), state.getOrDefault(groupId)));
       }
     }
   }
@@ -198,10 +175,9 @@ public final class MinLongGroupingAggregatorFunction implements GroupingAggregat
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     LongArrayState inState = ((MinLongGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     if (inState.hasValue(position)) {
-      state.set(MinLongAggregator.combine(state.getOrDefault(groupId), inState.get(position)), groupId);
-    } else {
-      state.putNull(groupId);
+      state.set(groupId, MinLongAggregator.combine(state.getOrDefault(groupId), inState.get(position)));
     }
   }
 

+ 8 - 28
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileDoubleGroupingAggregatorFunction.java

@@ -61,24 +61,27 @@ public final class PercentileDoubleGroupingAggregatorFunction implements Groupin
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     DoubleBlock valuesBlock = (DoubleBlock) uncastValuesBlock;
     DoubleVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -108,7 +111,6 @@ public final class PercentileDoubleGroupingAggregatorFunction implements Groupin
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -126,14 +128,6 @@ public final class PercentileDoubleGroupingAggregatorFunction implements Groupin
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
-    }
-  }
-
   private void addRawInput(int positionOffset, LongBlock groups, DoubleBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -144,7 +138,6 @@ public final class PercentileDoubleGroupingAggregatorFunction implements Groupin
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -170,23 +163,9 @@ public final class PercentileDoubleGroupingAggregatorFunction implements Groupin
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
-      }
-    }
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     BytesRefVector quart = page.<BytesRefBlock>getBlock(channels.get(0)).asVector();
     BytesRef scratch = new BytesRef();
@@ -202,6 +181,7 @@ public final class PercentileDoubleGroupingAggregatorFunction implements Groupin
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     QuantileStates.GroupingState inState = ((PercentileDoubleGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     PercentileDoubleAggregator.combineStates(state, groupId, inState, position);
   }
 

+ 8 - 28
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileIntGroupingAggregatorFunction.java

@@ -60,24 +60,27 @@ public final class PercentileIntGroupingAggregatorFunction implements GroupingAg
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     IntBlock valuesBlock = (IntBlock) uncastValuesBlock;
     IntVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -107,7 +110,6 @@ public final class PercentileIntGroupingAggregatorFunction implements GroupingAg
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -125,14 +127,6 @@ public final class PercentileIntGroupingAggregatorFunction implements GroupingAg
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
-    }
-  }
-
   private void addRawInput(int positionOffset, LongBlock groups, IntBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -143,7 +137,6 @@ public final class PercentileIntGroupingAggregatorFunction implements GroupingAg
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -169,23 +162,9 @@ public final class PercentileIntGroupingAggregatorFunction implements GroupingAg
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
-      }
-    }
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     BytesRefVector quart = page.<BytesRefBlock>getBlock(channels.get(0)).asVector();
     BytesRef scratch = new BytesRef();
@@ -201,6 +180,7 @@ public final class PercentileIntGroupingAggregatorFunction implements GroupingAg
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     QuantileStates.GroupingState inState = ((PercentileIntGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     PercentileIntAggregator.combineStates(state, groupId, inState, position);
   }
 

+ 8 - 28
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/PercentileLongGroupingAggregatorFunction.java

@@ -59,24 +59,27 @@ public final class PercentileLongGroupingAggregatorFunction implements GroupingA
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     LongBlock valuesBlock = (LongBlock) uncastValuesBlock;
     LongVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -106,7 +109,6 @@ public final class PercentileLongGroupingAggregatorFunction implements GroupingA
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -124,14 +126,6 @@ public final class PercentileLongGroupingAggregatorFunction implements GroupingA
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
-    }
-  }
-
   private void addRawInput(int positionOffset, LongBlock groups, LongBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -142,7 +136,6 @@ public final class PercentileLongGroupingAggregatorFunction implements GroupingA
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -168,23 +161,9 @@ public final class PercentileLongGroupingAggregatorFunction implements GroupingA
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
-      }
-    }
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     BytesRefVector quart = page.<BytesRefBlock>getBlock(channels.get(0)).asVector();
     BytesRef scratch = new BytesRef();
@@ -200,6 +179,7 @@ public final class PercentileLongGroupingAggregatorFunction implements GroupingA
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     QuantileStates.GroupingState inState = ((PercentileLongGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     PercentileLongAggregator.combineStates(state, groupId, inState, position);
   }
 

+ 8 - 28
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumDoubleGroupingAggregatorFunction.java

@@ -59,24 +59,27 @@ public final class SumDoubleGroupingAggregatorFunction implements GroupingAggreg
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     DoubleBlock valuesBlock = (DoubleBlock) uncastValuesBlock;
     DoubleVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -106,7 +109,6 @@ public final class SumDoubleGroupingAggregatorFunction implements GroupingAggreg
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -124,14 +126,6 @@ public final class SumDoubleGroupingAggregatorFunction implements GroupingAggreg
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
-    }
-  }
-
   private void addRawInput(int positionOffset, LongBlock groups, DoubleBlock values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       if (groups.isNull(groupPosition)) {
@@ -142,7 +136,6 @@ public final class SumDoubleGroupingAggregatorFunction implements GroupingAggreg
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
@@ -168,23 +161,9 @@ public final class SumDoubleGroupingAggregatorFunction implements GroupingAggreg
     }
   }
 
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
-      }
-    }
-  }
-
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     DoubleVector value = page.<DoubleBlock>getBlock(channels.get(0)).asVector();
     DoubleVector delta = page.<DoubleBlock>getBlock(channels.get(1)).asVector();
@@ -202,6 +181,7 @@ public final class SumDoubleGroupingAggregatorFunction implements GroupingAggreg
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     SumDoubleAggregator.GroupingSumState inState = ((SumDoubleGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     SumDoubleAggregator.combineStates(state, groupId, inState, position);
   }
 

+ 14 - 38
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumIntGroupingAggregatorFunction.java

@@ -53,24 +53,27 @@ public final class SumIntGroupingAggregatorFunction implements GroupingAggregato
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     IntBlock valuesBlock = (IntBlock) uncastValuesBlock;
     IntVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -100,13 +103,12 @@ public final class SumIntGroupingAggregatorFunction implements GroupingAggregato
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
       int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
       for (int v = valuesStart; v < valuesEnd; v++) {
-        state.set(SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v)), groupId);
+        state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v)));
       }
     }
   }
@@ -114,15 +116,7 @@ public final class SumIntGroupingAggregatorFunction implements GroupingAggregato
   private void addRawInput(int positionOffset, LongVector groups, IntVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      state.set(SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset)), groupId);
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
+      state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset)));
     }
   }
 
@@ -136,13 +130,12 @@ public final class SumIntGroupingAggregatorFunction implements GroupingAggregato
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
         int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
         for (int v = valuesStart; v < valuesEnd; v++) {
-          state.set(SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v)), groupId);
+          state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(v)));
         }
       }
     }
@@ -157,28 +150,14 @@ public final class SumIntGroupingAggregatorFunction implements GroupingAggregato
       int groupEnd = groupStart + groups.getValueCount(groupPosition);
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
-        state.set(SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset)), groupId);
-      }
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
+        state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), values.getInt(groupPosition + positionOffset)));
       }
     }
   }
 
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     LongVector sum = page.<LongBlock>getBlock(channels.get(0)).asVector();
     BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
@@ -186,9 +165,7 @@ public final class SumIntGroupingAggregatorFunction implements GroupingAggregato
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (seen.getBoolean(groupPosition + positionOffset)) {
-        state.set(SumIntAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset)), groupId);
-      } else {
-        state.putNull(groupId);
+        state.set(groupId, SumIntAggregator.combine(sum.getLong(groupPosition + positionOffset), state.getOrDefault(groupId)));
       }
     }
   }
@@ -199,10 +176,9 @@ public final class SumIntGroupingAggregatorFunction implements GroupingAggregato
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     LongArrayState inState = ((SumIntGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     if (inState.hasValue(position)) {
-      state.set(SumIntAggregator.combine(state.getOrDefault(groupId), inState.get(position)), groupId);
-    } else {
-      state.putNull(groupId);
+      state.set(groupId, SumIntAggregator.combine(state.getOrDefault(groupId), inState.get(position)));
     }
   }
 

+ 14 - 38
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java

@@ -52,24 +52,27 @@ public final class SumLongGroupingAggregatorFunction implements GroupingAggregat
   }
 
   @Override
-  public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
     Block uncastValuesBlock = page.getBlock(channels.get(0));
     if (uncastValuesBlock.areAllValuesNull()) {
+      state.enableGroupIdTracking(seenGroupIds);
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
 
         @Override
         public void add(int positionOffset, LongVector groupIds) {
-          addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock);
         }
       };
     }
     LongBlock valuesBlock = (LongBlock) uncastValuesBlock;
     LongVector valuesVector = valuesBlock.asVector();
     if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
       return new GroupingAggregatorFunction.AddInput() {
         @Override
         public void add(int positionOffset, LongBlock groupIds) {
@@ -99,13 +102,12 @@ public final class SumLongGroupingAggregatorFunction implements GroupingAggregat
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (values.isNull(groupPosition + positionOffset)) {
-        state.putNull(groupId);
         continue;
       }
       int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
       int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
       for (int v = valuesStart; v < valuesEnd; v++) {
-        state.set(SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)), groupId);
+        state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)));
       }
     }
   }
@@ -113,15 +115,7 @@ public final class SumLongGroupingAggregatorFunction implements GroupingAggregat
   private void addRawInput(int positionOffset, LongVector groups, LongVector values) {
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      state.set(SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)), groupId);
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongVector groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = Math.toIntExact(groups.getLong(groupPosition));
-      assert values.isNull(groupPosition + positionOffset);
-      state.putNull(groupPosition + positionOffset);
+      state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)));
     }
   }
 
@@ -135,13 +129,12 @@ public final class SumLongGroupingAggregatorFunction implements GroupingAggregat
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
         if (values.isNull(groupPosition + positionOffset)) {
-          state.putNull(groupId);
           continue;
         }
         int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
         int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
         for (int v = valuesStart; v < valuesEnd; v++) {
-          state.set(SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)), groupId);
+          state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v)));
         }
       }
     }
@@ -156,28 +149,14 @@ public final class SumLongGroupingAggregatorFunction implements GroupingAggregat
       int groupEnd = groupStart + groups.getValueCount(groupPosition);
       for (int g = groupStart; g < groupEnd; g++) {
         int groupId = Math.toIntExact(groups.getLong(g));
-        state.set(SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)), groupId);
-      }
-    }
-  }
-
-  private void addRawInputAllNulls(int positionOffset, LongBlock groups, Block values) {
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = Math.toIntExact(groups.getLong(g));
-        assert values.isNull(groupPosition + positionOffset);
-        state.putNull(groupPosition + positionOffset);
+        state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset)));
       }
     }
   }
 
   @Override
   public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     assert channels.size() == intermediateBlockCount();
     LongVector sum = page.<LongBlock>getBlock(channels.get(0)).asVector();
     BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
@@ -185,9 +164,7 @@ public final class SumLongGroupingAggregatorFunction implements GroupingAggregat
     for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
       int groupId = Math.toIntExact(groups.getLong(groupPosition));
       if (seen.getBoolean(groupPosition + positionOffset)) {
-        state.set(SumLongAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset)), groupId);
-      } else {
-        state.putNull(groupId);
+        state.set(groupId, SumLongAggregator.combine(sum.getLong(groupPosition + positionOffset), state.getOrDefault(groupId)));
       }
     }
   }
@@ -198,10 +175,9 @@ public final class SumLongGroupingAggregatorFunction implements GroupingAggregat
       throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
     }
     LongArrayState inState = ((SumLongGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
     if (inState.hasValue(position)) {
-      state.set(SumLongAggregator.combine(state.getOrDefault(groupId), inState.get(position)), groupId);
-    } else {
-      state.putNull(groupId);
+      state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), inState.get(position)));
     }
   }
 

+ 53 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractArrayState.java

@@ -0,0 +1,53 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+/**
+ * Base class for grouping aggregation states that can optionally remember
+ * which group ids have received a value. Tracking starts disabled; while it
+ * is disabled {@link #hasValue} answers {@code true} for every group id.
+ * Calling {@link #enableGroupIdTracking} switches tracking on.
+ */
+public class AbstractArrayState implements Releasable {
+    protected final BigArrays bigArrays;
+    /**
+     * Bits for the group ids that have been marked as seen, or {@code null}
+     * while group id tracking is disabled.
+     */
+    private BitArray seen;
+
+    public AbstractArrayState(BigArrays bigArrays) {
+        this.bigArrays = bigArrays;
+    }
+    /**
+     * Returns {@code true} if group id tracking is disabled, or if it is
+     * enabled and {@code groupId} has been marked as seen.
+     */
+    final boolean hasValue(int groupId) {
+        return seen == null || seen.get(groupId);
+    }
+
+    /**
+     * Switches this array state into tracking which group ids are set. This is
+     * idempotent and fast if already tracking so it's safe to, say, call it once
+     * for every block of values that arrives containing {@code null}.
+     *
+     * @param seenGroupIds supplies the initial set of seen group ids; it is only
+     *                     consulted when tracking is not already enabled
+     */
+    final void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
+        if (seen == null) {
+            seen = seenGroupIds.seenGroupIds(bigArrays);
+        }
+    }
+    /**
+     * Marks {@code groupId} as seen when tracking is enabled; does nothing
+     * while tracking is disabled.
+     */
+    protected final void trackGroupId(int groupId) {
+        if (trackingGroupIds()) {
+            seen.set(groupId);
+        }
+    }
+    /**
+     * Returns {@code true} once the set of seen group ids has been created,
+     * i.e. {@code seen} is no longer {@code null}.
+     */
+    protected final boolean trackingGroupIds() {
+        return seen != null;
+    }
+    /**
+     * Releases the seen-group-ids bits.
+     * NOTE(review): {@code seen} is still {@code null} here if tracking was never
+     * enabled; presumably {@code Releasables.close} tolerates {@code null} - confirm.
+     */
+    @Override
+    public void close() {
+        Releasables.close(seen);
+    }
+}

+ 5 - 15
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountDistinctBooleanAggregator.java

@@ -102,12 +102,12 @@ public class CountDistinctBooleanAggregator {
      * This means that false values for a groupId are stored at bits[2*groupId] and
      * true values for a groupId are stored at bits[2*groupId + 1]
      */
-    static class GroupingState implements GroupingAggregatorState {
+    static class GroupingState extends AbstractArrayState implements GroupingAggregatorState {
 
         final BitArray bits;
-        int largestGroupId; // total number of groups; <= bytes.length
 
         GroupingState(BigArrays bigArrays) {
+            super(bigArrays);
             boolean success = false;
             try {
                 this.bits = new BitArray(2, bigArrays); // Start with two bits for a single groupId
@@ -120,23 +120,13 @@ public class CountDistinctBooleanAggregator {
         }
 
         void collect(int groupId, boolean v) {
-            ensureCapacity(groupId);
             bits.set(groupId * 2 + (v ? 1 : 0));
+            trackGroupId(groupId);
         }
 
         void combineStates(int currentGroupId, GroupingState state) {
-            ensureCapacity(currentGroupId);
             bits.or(state.bits);
-        }
-
-        void putNull(int groupId) {
-            ensureCapacity(groupId);
-        }
-
-        void ensureCapacity(int groupId) {
-            if (groupId > largestGroupId) {
-                largestGroupId = groupId;
-            }
+            trackGroupId(currentGroupId);
         }
 
         /** Extracts an intermediate view of the contents of this state.  */
@@ -156,7 +146,7 @@ public class CountDistinctBooleanAggregator {
 
         @Override
         public void close() {
-            Releasables.close(bits);
+            Releasables.close(bits, super::close);
         }
     }
 }

+ 41 - 36
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java

@@ -49,10 +49,11 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
     }
 
     @Override
-    public AddInput prepareProcessPage(Page page) {
+    public AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) {
         Block valuesBlock = page.getBlock(channels.get(0));
         if (valuesBlock.areAllValuesNull()) {
-            return new AddInput() {
+            state.enableGroupIdTracking(seenGroupIds);
+            return new AddInput() { // TODO return null meaning "don't collect me" and skip those
                 @Override
                 public void add(int positionOffset, LongBlock groupIds) {}
 
@@ -62,6 +63,9 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
         }
         Vector valuesVector = valuesBlock.asVector();
         if (valuesVector == null) {
+            if (valuesBlock.mayHaveNulls()) {
+                state.enableGroupIdTracking(seenGroupIds);
+            }
             return new AddInput() {
                 @Override
                 public void add(int positionOffset, LongBlock groupIds) {
@@ -73,19 +77,18 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
                     addRawInput(positionOffset, groupIds, valuesBlock);
                 }
             };
-        } else {
-            return new AddInput() {
-                @Override
-                public void add(int positionOffset, LongBlock groupIds) {
-                    addRawInput(groupIds);
-                }
-
-                @Override
-                public void add(int positionOffset, LongVector groupIds) {
-                    addRawInput(groupIds);
-                }
-            };
         }
+        return new AddInput() {
+            @Override
+            public void add(int positionOffset, LongBlock groupIds) {
+                addRawInput(groupIds);
+            }
+
+            @Override
+            public void add(int positionOffset, LongVector groupIds) {
+                addRawInput(groupIds);
+            }
+        };
     }
 
     private void addRawInput(int positionOffset, LongVector groups, Block values) {
@@ -93,22 +96,15 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, position++) {
             int groupId = Math.toIntExact(groups.getLong(groupPosition));
             if (values.isNull(position)) {
-                state.putNull(groupId);
                 continue;
             }
-            state.increment(values.getValueCount(position), groupId);
+            state.increment(groupId, values.getValueCount(position));
         }
     }
 
-    private void addRawInput(LongVector groups) {
-        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-            int groupId = Math.toIntExact(groups.getLong(groupPosition));
-            state.increment(1, groupId);
-        }
-    }
-
-    private void addRawInput(LongBlock groups) {
-        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+    private void addRawInput(int positionOffset, LongBlock groups, Block values) {
+        int position = positionOffset;
+        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, position++) {
             if (groups.isNull(groupPosition)) {
                 continue;
             }
@@ -116,14 +112,23 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
             int groupEnd = groupStart + groups.getValueCount(groupPosition);
             for (int g = groupStart; g < groupEnd; g++) {
                 int groupId = Math.toIntExact(groups.getLong(g));
-                state.increment(1, groupId);
+                if (values.isNull(position)) {
+                    continue;
+                }
+                state.increment(groupId, values.getValueCount(position));
             }
         }
     }
 
-    private void addRawInput(int positionOffset, LongBlock groups, Block values) {
-        int position = positionOffset;
-        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, position++) {
+    private void addRawInput(LongVector groups) { // path used by prepareProcessPage when the values block is a vector; values are not inspected
+        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+            int groupId = Math.toIntExact(groups.getLong(groupPosition)); // group ids arrive as longs; fails loudly if one does not fit in an int
+            state.increment(groupId, 1); // every position adds exactly 1 to its group's count
+        }
+    }
+
+    private void addRawInput(LongBlock groups) {
+        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             if (groups.isNull(groupPosition)) {
                 continue;
             }
@@ -131,11 +136,7 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
             int groupEnd = groupStart + groups.getValueCount(groupPosition);
             for (int g = groupStart; g < groupEnd; g++) {
                 int groupId = Math.toIntExact(groups.getLong(g));
-                if (values.isNull(position)) {
-                    state.putNull(groupId);
-                    continue;
-                }
-                state.increment(values.getValueCount(position), groupId);
+                state.increment(groupId, 1);
             }
         }
     }
@@ -144,11 +145,12 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
     public void addIntermediateInput(int positionOffset, LongVector groups, Page page) {
         assert channels.size() == intermediateBlockCount();
         assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
+        state.enableGroupIdTracking(new SeenGroupIds.Empty());
         LongVector count = page.<LongBlock>getBlock(channels.get(0)).asVector();
         BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
         assert count.getPositionCount() == seen.getPositionCount();
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-            state.increment(count.getLong(groupPosition + positionOffset), Math.toIntExact(groups.getLong(groupPosition)));
+            state.increment(Math.toIntExact(groups.getLong(groupPosition)), count.getLong(groupPosition + positionOffset));
         }
     }
 
@@ -158,7 +160,10 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
             throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
         }
         final LongArrayState inState = ((CountGroupingAggregatorFunction) input).state;
-        state.increment(inState.get(position), groupId);
+        state.enableGroupIdTracking(new SeenGroupIds.Empty());
+        if (inState.hasValue(position)) {
+            state.increment(groupId, inState.get(position));
+        }
     }
 
     @Override

+ 2 - 2
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregator.java

@@ -38,7 +38,7 @@ public class GroupingAggregator implements Releasable {
     /**
      * Prepare to process a single page of results.
      */
-    public GroupingAggregatorFunction.AddInput prepareProcessPage(Page page) {
+    public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) {
         if (mode.isInputPartial()) {
             return new GroupingAggregatorFunction.AddInput() {
                 @Override
@@ -52,7 +52,7 @@ public class GroupingAggregator implements Releasable {
                 }
             };
         } else {
-            return aggregatorFunction.prepareProcessPage(page);
+            return aggregatorFunction.prepareProcessPage(seenGroupIds, page);
         }
     }
 

+ 1 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunction.java

@@ -74,7 +74,7 @@ public interface GroupingAggregatorFunction extends Releasable {
      *     select an optimal path and return that path as an {@link AddInput}.
      * </p>
      */
-    AddInput prepareProcessPage(Page page);  // TODO allow returning null to opt out of the callback loop
+    AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page);  // TODO allow returning null to opt out of the callback loop
 
     /**
      * Add data produced by {@link #evaluateIntermediate}.

+ 4 - 10
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java

@@ -134,16 +134,14 @@ final class HllStates {
 
         final HyperLogLogPlusPlus hll;
 
-        /**
-         * Maximum group id received. Only needed for estimating max serialization size.
-         * We won't need to do that one day and can remove this.
-         */
-        int maxGroupId;
-
         GroupingState(BigArrays bigArrays, int precision) {
             this.hll = new HyperLogLogPlusPlus(HyperLogLogPlusPlus.precisionFromThreshold(precision), bigArrays, 1);
         }
 
+        void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
+            // Nothing to do
+        }
+
         void collect(int groupId, long v) {
             doCollect(groupId, BitMixer.mix64(v));
         }
@@ -169,10 +167,6 @@ final class HllStates {
             return hll.cardinality(groupId);
         }
 
-        void putNull(int groupId) {
-            maxGroupId = Math.max(maxGroupId, groupId);
-        }
-
         void merge(int groupId, AbstractHyperLogLogPlusPlus other, int otherGroup) {
             hll.merge(groupId, other, otherGroup);
         }

+ 13 - 10
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/QuantileStates.java

@@ -121,10 +121,7 @@ public final class QuantileStates {
         }
 
         private TDigestState getOrAddGroup(int groupId) {
-            if (groupId > largestGroupId) {
-                digests = bigArrays.grow(digests, groupId + 1);
-                largestGroupId = groupId;
-            }
+            digests = bigArrays.grow(digests, groupId + 1);
             TDigestState qs = digests.get(groupId);
             if (qs == null) {
                 qs = TDigestState.create(DEFAULT_COMPRESSION);
@@ -133,16 +130,18 @@ public final class QuantileStates {
             return qs;
         }
 
-        void putNull(int groupId) {
-            getOrAddGroup(groupId);
-        }
-
         void add(int groupId, double v) {
             getOrAddGroup(groupId).add(v);
         }
 
         void add(int groupId, TDigestState other) {
-            getOrAddGroup(groupId).add(other);
+            if (other != null) {
+                getOrAddGroup(groupId).add(other);
+            }
+        }
+
+        /**
+         * Intentionally a no-op: this state has nothing to switch on when asked
+         * to start tracking group ids, so {@code seenGroupIds} is ignored.
+         * NOTE(review): presumably a missing (null) digest for a group already
+         * marks that group as having no values - confirm.
+         */
+        void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
+            // We always enable.
         }
 
         void add(int groupId, BytesRef other) {
@@ -160,7 +159,11 @@ public final class QuantileStates {
             var builder = BytesRefBlock.newBlockBuilder(selected.getPositionCount());
             for (int i = 0; i < selected.getPositionCount(); i++) {
                 int group = selected.getInt(i);
-                builder.appendBytesRef(serializeDigest(get(group)));
+                TDigestState state = get(group);
+                if (state == null) {
+                    state = TDigestState.create(DEFAULT_COMPRESSION);
+                }
+                builder.appendBytesRef(serializeDigest(state));
             }
             blocks[offset] = builder.build();
         }

+ 38 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SeenGroupIds.java

@@ -0,0 +1,38 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+
+/**
+ * Source of the group ids that have already been seen. Aggregation state
+ * receives one of these when it is asked to start tracking group ids
+ * (see the callers of {@code enableGroupIdTracking}).
+ */
+public interface SeenGroupIds {
+    /**
+     * The grouping ids that have been seen already. This {@link BitArray} is
+     * kept and mutated by the caller so make a copy if it's something you
+     * need your own copy of.
+     *
+     * @param bigArrays allocator used to build the returned {@link BitArray}
+     * @return a {@link BitArray} with one bit set per seen group id
+     */
+    BitArray seenGroupIds(BigArrays bigArrays);
+
+    /**
+     * No group ids have been seen: the returned {@link BitArray} has no bits set.
+     */
+    record Empty() implements SeenGroupIds {
+        @Override
+        public BitArray seenGroupIds(BigArrays bigArrays) {
+            return new BitArray(1, bigArrays);
+        }
+    }
+
+    /**
+     * Every group id in {@code [from, to)} has been seen.
+     *
+     * @param from first seen group id, inclusive
+     * @param to end of the seen group ids, exclusive
+     */
+    record Range(int from, int to) implements SeenGroupIds {
+        @Override
+        public BitArray seenGroupIds(BigArrays bigArrays) {
+            // Size for the highest index written (to - 1) rather than the width
+            // of the range, so a range that doesn't start at 0 isn't undersized.
+            BitArray seen = new BitArray(to, bigArrays);
+            for (int i = from; i < to; i++) {
+                seen.set(i);
+            }
+            return seen;
+        }
+    }
+}

+ 15 - 52
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumDoubleAggregator.java

@@ -8,7 +8,6 @@
 package org.elasticsearch.compute.aggregation;
 
 import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.DoubleArray;
 import org.elasticsearch.compute.ann.Aggregator;
 import org.elasticsearch.compute.ann.GroupingAggregator;
@@ -77,16 +76,12 @@ class SumDoubleAggregator {
     public static void combineStates(GroupingSumState current, int groupId, GroupingSumState state, int statePosition) {
         if (state.hasValue(statePosition)) {
             current.add(state.values.get(statePosition), state.deltas.get(statePosition), groupId);
-        } else {
-            current.putNull(groupId);
         }
     }
 
     public static void combineIntermediate(GroupingSumState current, int groupId, double inValue, double inDelta, boolean seen) {
         if (seen) {
             current.add(inValue, inDelta, groupId);
-        } else {
-            current.putNull(groupId);
         }
     }
 
@@ -94,22 +89,21 @@ class SumDoubleAggregator {
         assert blocks.length >= offset + 3;
         var valuesBuilder = DoubleBlock.newBlockBuilder(selected.getPositionCount());
         var deltaBuilder = DoubleBlock.newBlockBuilder(selected.getPositionCount());
-        var nullsBuilder = BooleanBlock.newBlockBuilder(selected.getPositionCount());
+        var seenBuilder = BooleanBlock.newBlockBuilder(selected.getPositionCount());
         for (int i = 0; i < selected.getPositionCount(); i++) {
             int group = selected.getInt(i);
-            valuesBuilder.appendDouble(state.values.get(group));
-            deltaBuilder.appendDouble(state.deltas.get(group));
-            if (state.seen != null) {
-                nullsBuilder.appendBoolean(state.seen.get(group));
+            if (group < state.values.size()) {
+                valuesBuilder.appendDouble(state.values.get(group));
+                deltaBuilder.appendDouble(state.deltas.get(group));
+            } else {
+                valuesBuilder.appendDouble(0);
+                deltaBuilder.appendDouble(0);
             }
+            seenBuilder.appendBoolean(state.hasValue(group));
         }
         blocks[offset + 0] = valuesBuilder.build();
         blocks[offset + 1] = deltaBuilder.build();
-        if (state.seen != null) {
-            blocks[offset + 2] = nullsBuilder.build();
-        } else {
-            blocks[offset + 2] = new ConstantBooleanVector(true, selected.getPositionCount()).asBlock();
-        }
+        blocks[offset + 2] = seenBuilder.build();
     }
 
     public static Block evaluateFinal(GroupingSumState state, IntVector selected) {
@@ -153,20 +147,14 @@ class SumDoubleAggregator {
         }
     }
 
-    static class GroupingSumState implements GroupingAggregatorState {
-        private final BigArrays bigArrays;
+    static class GroupingSumState extends AbstractArrayState implements GroupingAggregatorState {
         static final long BYTES_SIZE = Double.BYTES + Double.BYTES;
 
         DoubleArray values;
         DoubleArray deltas;
 
-        // total number of groups; <= values.length
-        int largestGroupId;
-
-        private BitArray seen;
-
         GroupingSumState(BigArrays bigArrays) {
-            this.bigArrays = bigArrays;
+            super(bigArrays);
             boolean success = false;
             try {
                 this.values = bigArrays.newDoubleArray(1);
@@ -203,37 +191,12 @@ class SumDoubleAggregator {
             double updatedValue = value + correctedSum;
             deltas.set(groupId, correctedSum - (updatedValue - value));
             values.set(groupId, updatedValue);
-            if (seen != null) {
-                seen.set(groupId);
-            }
-        }
-
-        void putNull(int groupId) {
-            if (groupId > largestGroupId) {
-                ensureCapacity(groupId);
-                largestGroupId = groupId;
-            }
-            if (seen == null) {
-                seen = new BitArray(groupId + 1, bigArrays);
-                for (int i = 0; i < groupId; i++) {
-                    seen.set(i);
-                }
-            } else {
-                // Do nothing. Null is represented by the default value of false for get(int),
-                // and any present value trumps a null value in our aggregations.
-            }
-        }
-
-        boolean hasValue(int index) {
-            return seen == null || seen.get(index);
+            trackGroupId(groupId);
         }
 
         private void ensureCapacity(int groupId) {
-            if (groupId > largestGroupId) {
-                largestGroupId = groupId;
-                values = bigArrays.grow(values, groupId + 1);
-                deltas = bigArrays.grow(deltas, groupId + 1);
-            }
+            values = bigArrays.grow(values, groupId + 1);
+            deltas = bigArrays.grow(deltas, groupId + 1);
         }
 
         @Override
@@ -243,7 +206,7 @@ class SumDoubleAggregator {
 
         @Override
         public void close() {
-            Releasables.close(values, deltas, seen);
+            Releasables.close(values, deltas, () -> super.close());
         }
     }
 }

+ 39 - 61
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ArrayState.java.st

@@ -8,7 +8,6 @@
 package org.elasticsearch.compute.aggregation;
 
 import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.$Type$Array;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BooleanBlock;
@@ -23,81 +22,56 @@ $endif$
 import org.elasticsearch.core.Releasables;
 
 /**
- * Aggregator state for an array of $type$s.
+ * Aggregator state for an array of $type$s. It is created in a mode where it
+ * won't track the {@code groupId}s that are sent to it and it is the
+ * responsibility of the caller to only fetch values for {@code groupId}s
+ * that it has sent using the {@code selected} parameter when building the
+ * results. This is fine when there are no {@code null} values in the input
+ * data. But once there are null values in the input data it is
+ * <strong>much</strong> more convenient to only send non-null values and
+ * the tracking built into the grouping code can't track that. In that case
+ * call {@link #enableGroupIdTracking} to transition the state into a mode
+ * where it'll track which {@code groupIds} have been written.
+ * <p>
  * This class is generated. Do not edit it.
+ * </p>
  */
-final class $Type$ArrayState implements GroupingAggregatorState {
-    private final BigArrays bigArrays;
+final class $Type$ArrayState extends AbstractArrayState implements GroupingAggregatorState {
     private final $type$ init;
 
     private $Type$Array values;
-    /**
-     * Total number of groups {@code <=} values.length.
-     */
-    private int largestIndex;
-    private BitArray nonNulls;
 
     $Type$ArrayState(BigArrays bigArrays, $type$ init) {
-        this.bigArrays = bigArrays;
+        super(bigArrays);
         this.values = bigArrays.new$Type$Array(1, false);
         this.values.set(0, init);
         this.init = init;
     }
 
-    $type$ get(int index) {
-        return values.get(index);
+    $type$ get(int groupId) {
+        return values.get(groupId);
     }
 
-    $type$ getOrDefault(int index) {
-        return index <= largestIndex ? values.get(index) : init;
+    $type$ getOrDefault(int groupId) {
+        return groupId < values.size() ? values.get(groupId) : init;
     }
 
-    void set($type$ value, int index) {
-        if (index > largestIndex) {
-            ensureCapacity(index);
-            largestIndex = index;
-        }
-        values.set(index, value);
-        if (nonNulls != null) {
-            nonNulls.set(index);
-        }
+    void set(int groupId, $type$ value) {
+        ensureCapacity(groupId);
+        values.set(groupId, value);
+        trackGroupId(groupId);
     }
 
 $if(long)$
-    void increment(long value, int index) {
-        if (index > largestIndex) {
-            ensureCapacity(index);
-            largestIndex = index;
-        }
-        values.increment(index, value);
-        if (nonNulls != null) {
-            nonNulls.set(index);
-        }
+    void increment(int groupId, long value) {
+        ensureCapacity(groupId);
+        values.increment(groupId, value);
+        trackGroupId(groupId);
     }
 $endif$
 
-    void putNull(int index) {
-        if (index > largestIndex) {
-            ensureCapacity(index);
-            largestIndex = index;
-        }
-        if (nonNulls == null) {
-            nonNulls = new BitArray(index + 1, bigArrays);
-            for (int i = 0; i < index; i++) {
-                nonNulls.set(i);
-            }
-        } else {
-            // Do nothing. Null is represented by the default value of false for get(int),
-            // and any present value trumps a null value in our aggregations.
-        }
-    }
-
-    boolean hasValue(int index) {
-        return nonNulls == null || nonNulls.get(index);
-    }
-
     Block toValuesBlock(org.elasticsearch.compute.data.IntVector selected) {
-        if (nonNulls == null) {
+        if (false == trackingGroupIds()) {
             $Type$Vector.Builder builder = $Type$Vector.newVectorBuilder(selected.getPositionCount());
             for (int i = 0; i < selected.getPositionCount(); i++) {
                 builder.append$Type$(values.get(selected.getInt(i)));
@@ -116,10 +90,10 @@ $endif$
         return builder.build();
     }
 
-    private void ensureCapacity(int position) {
-        if (position >= values.size()) {
+    private void ensureCapacity(int groupId) {
+        if (groupId >= values.size()) {
             long prevSize = values.size();
-            values = bigArrays.grow(values, position + 1);
+            values = bigArrays.grow(values, groupId + 1);
             values.fill(prevSize, values.size(), init);
         }
     }
@@ -129,18 +103,22 @@ $endif$
     public void toIntermediate(Block[] blocks, int offset, IntVector selected) {
         assert blocks.length >= offset + 2;
         var valuesBuilder = $Type$Block.newBlockBuilder(selected.getPositionCount());
-        var nullsBuilder = BooleanBlock.newBlockBuilder(selected.getPositionCount());
+        var hasValueBuilder = BooleanBlock.newBlockBuilder(selected.getPositionCount());
         for (int i = 0; i < selected.getPositionCount(); i++) {
             int group = selected.getInt(i);
-            valuesBuilder.append$Type$(values.get(group));
-            nullsBuilder.appendBoolean(hasValue(group));
+            if (group < values.size()) {
+                valuesBuilder.append$Type$(values.get(group));
+            } else {
+                valuesBuilder.append$Type$(0); // TODO can we just use null?
+            }
+            hasValueBuilder.appendBoolean(hasValue(group));
         }
         blocks[offset + 0] = valuesBuilder.build();
-        blocks[offset + 1] = nullsBuilder.build();
+        blocks[offset + 1] = hasValueBuilder.build();
     }
 
     @Override
     public void close() {
-        Releasables.close(values, nonNulls);
+        Releasables.close(values, super::close);
     }
 }

+ 7 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java

@@ -8,9 +8,11 @@
 package org.elasticsearch.compute.aggregation.blockhash;
 
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.common.util.LongHash;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntVector;
@@ -27,7 +29,7 @@ import java.util.List;
  * @see LongHash
  * @see BytesRefHash
  */
-public abstract sealed class BlockHash implements Releasable //
+public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
     permits BooleanBlockHash, BytesRefBlockHash, DoubleBlockHash, IntBlockHash, LongBlockHash,//
     PackedValuesBlockHash, BytesRefLongBlockHash, LongLongBlockHash {
 
@@ -51,6 +53,10 @@ public abstract sealed class BlockHash implements Releasable //
      */
     public abstract IntVector nonEmpty();
 
+    // TODO merge with nonEmpty
+    @Override
+    public abstract BitArray seenGroupIds(BigArrays bigArrays);
+
     /**
      * Creates a specialized hash table that maps one or more {@link Block}s to ids.
      * @param emitBatchSize maximum batch size to be emitted when handling combinatorial

+ 13 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BooleanBlockHash.java

@@ -7,6 +7,8 @@
 
 package org.elasticsearch.compute.aggregation.blockhash;
 
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.data.BooleanBlock;
 import org.elasticsearch.compute.data.BooleanVector;
@@ -76,6 +78,17 @@ final class BooleanBlockHash extends BlockHash {
         return builder.build();
     }
 
+    /**
+     * Builds a two-bit {@link BitArray} of the group ids this hash has seen:
+     * bit 0 is set when {@code everSeen[0]} is true and bit 1 when
+     * {@code everSeen[1]} is true. Either bit may be unset, so this is not
+     * simply a contiguous range of ids.
+     *
+     * @param bigArrays allocator used to build the returned {@link BitArray}
+     * @return a newly allocated {@link BitArray} with the seen group ids set
+     */
+    @Override
+    public BitArray seenGroupIds(BigArrays bigArrays) {
+        BitArray seen = new BitArray(2, bigArrays);
+        if (everSeen[0]) {
+            seen.set(0);
+        }
+        if (everSeen[1]) {
+            seen.set(1);
+        }
+        return seen;
+    }
+
     @Override
     public void close() {
         // Nothing to close

+ 7 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BytesRefBlockHash.java

@@ -12,9 +12,11 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.BytesRefArray;
 import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.data.BytesRefArrayVector;
 import org.elasticsearch.compute.data.BytesRefBlock;
 import org.elasticsearch.compute.data.BytesRefVector;
@@ -105,6 +107,11 @@ final class BytesRefBlockHash extends BlockHash {
         return IntVector.range(0, Math.toIntExact(bytesRefHash.size()));
     }
 
+    @Override
+    public BitArray seenGroupIds(BigArrays bigArrays) {
+        return new SeenGroupIds.Range(0, Math.toIntExact(bytesRefHash.size())).seenGroupIds(bigArrays);
+    }
+
     @Override
     public void close() {
         bytesRefHash.close();

+ 7 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BytesRefLongBlockHash.java

@@ -10,9 +10,11 @@ package org.elasticsearch.compute.aggregation.blockhash;
 import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.common.util.LongLongHash;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BytesRefBlock;
 import org.elasticsearch.compute.data.BytesRefVector;
@@ -176,6 +178,11 @@ final class BytesRefLongBlockHash extends BlockHash {
         }
     }
 
+    @Override
+    public BitArray seenGroupIds(BigArrays bigArrays) {
+        return new SeenGroupIds.Range(0, Math.toIntExact(finalHash.size())).seenGroupIds(bigArrays);
+    }
+
     @Override
     public IntVector nonEmpty() {
         return IntVector.range(0, Math.toIntExact(finalHash.size()));

+ 7 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/DoubleBlockHash.java

@@ -8,8 +8,10 @@
 package org.elasticsearch.compute.aggregation.blockhash;
 
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.LongHash;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.data.DoubleArrayVector;
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
@@ -72,6 +74,11 @@ final class DoubleBlockHash extends BlockHash {
         return IntVector.range(0, Math.toIntExact(longHash.size()));
     }
 
+    @Override
+    public BitArray seenGroupIds(BigArrays bigArrays) {
+        return new SeenGroupIds.Range(0, Math.toIntExact(longHash.size())).seenGroupIds(bigArrays);
+    }
+
     @Override
     public void close() {
         longHash.close();

+ 7 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/IntBlockHash.java

@@ -8,8 +8,10 @@
 package org.elasticsearch.compute.aggregation.blockhash;
 
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.LongHash;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.data.IntArrayVector;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
@@ -69,6 +71,11 @@ final class IntBlockHash extends BlockHash {
         return IntVector.range(0, Math.toIntExact(longHash.size()));
     }
 
+    @Override
+    public BitArray seenGroupIds(BigArrays bigArrays) {
+        return new SeenGroupIds.Range(0, Math.toIntExact(longHash.size())).seenGroupIds(bigArrays);
+    }
+
     @Override
     public void close() {
         longHash.close();

+ 7 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/LongBlockHash.java

@@ -8,8 +8,10 @@
 package org.elasticsearch.compute.aggregation.blockhash;
 
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.LongHash;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongArrayVector;
 import org.elasticsearch.compute.data.LongBlock;
@@ -69,6 +71,11 @@ final class LongBlockHash extends BlockHash {
         return IntVector.range(0, Math.toIntExact(longHash.size()));
     }
 
+    @Override
+    public BitArray seenGroupIds(BigArrays bigArrays) {
+        return new SeenGroupIds.Range(0, Math.toIntExact(longHash.size())).seenGroupIds(bigArrays);
+    }
+
     @Override
     public void close() {
         longHash.close();

+ 7 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/LongLongBlockHash.java

@@ -9,8 +9,10 @@ package org.elasticsearch.compute.aggregation.blockhash;
 
 import org.apache.lucene.util.ArrayUtil;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.LongLongHash;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongArrayVector;
@@ -201,6 +203,11 @@ final class LongLongBlockHash extends BlockHash {
         return IntVector.range(0, Math.toIntExact(hash.size()));
     }
 
+    @Override
+    public BitArray seenGroupIds(BigArrays bigArrays) {
+        return new SeenGroupIds.Range(0, Math.toIntExact(hash.size())).seenGroupIds(bigArrays);
+    }
+
     @Override
     public String toString() {
         return "LongLongBlockHash{channels=[" + channel1 + "," + channel2 + "], entries=" + hash.size() + "}";

+ 7 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java

@@ -13,8 +13,10 @@ import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.BytesRefBuilder;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntVector;
@@ -250,6 +252,11 @@ final class PackedValuesBlockHash extends BlockHash {
         return IntVector.range(0, Math.toIntExact(bytesRefHash.size()));
     }
 
+    @Override
+    public BitArray seenGroupIds(BigArrays bigArrays) {
+        return new SeenGroupIds.Range(0, Math.toIntExact(bytesRefHash.size())).seenGroupIds(bigArrays);
+    }
+
     @Override
     public void close() {
         bytesRefHash.close();

+ 2 - 2
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java

@@ -68,10 +68,10 @@ public class HashAggregationOperator implements Operator {
         this.aggregators = new ArrayList<>(aggregators.size());
         boolean success = false;
         try {
+            this.blockHash = blockHash.get();
             for (GroupingAggregator.Factory a : aggregators) {
                 this.aggregators.add(a.apply(driverContext));
             }
-            this.blockHash = blockHash.get();
             success = true;
         } finally {
             if (success == false) {
@@ -92,7 +92,7 @@ public class HashAggregationOperator implements Operator {
 
         GroupingAggregatorFunction.AddInput[] prepared = new GroupingAggregatorFunction.AddInput[aggregators.size()];
         for (int i = 0; i < prepared.length; i++) {
-            prepared[i] = aggregators.get(i).prepareProcessPage(page);
+            prepared[i] = aggregators.get(i).prepareProcessPage(blockHash, page);
         }
 
         blockHash.add(wrapPage(page), new GroupingAggregatorFunction.AddInput() {

+ 30 - 8
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java

@@ -17,6 +17,8 @@ import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.compute.Describable;
 import org.elasticsearch.compute.aggregation.GroupingAggregator;
 import org.elasticsearch.compute.aggregation.GroupingAggregator.Factory;
+import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BytesRefBlock;
@@ -297,7 +299,7 @@ public class OrdinalsGroupingOperator implements Operator {
 
     }
 
-    static final class OrdinalSegmentAggregator implements Releasable {
+    static final class OrdinalSegmentAggregator implements Releasable, SeenGroupIds {
         private final List<GroupingAggregator> aggregators;
         private final ValuesSource.Bytes.WithOrdinals withOrdinals;
         private final LeafReaderContext leafReaderContext;
@@ -310,16 +312,29 @@ public class OrdinalsGroupingOperator implements Operator {
             LeafReaderContext leafReaderContext,
             BigArrays bigArrays
         ) throws IOException {
-            this.aggregators = aggregators;
-            this.withOrdinals = withOrdinals;
-            this.leafReaderContext = leafReaderContext;
-            final SortedSetDocValues sortedSetDocValues = withOrdinals.ordinalsValues(leafReaderContext);
-            this.currentReader = new BlockOrdinalsReader(sortedSetDocValues);
-            this.visitedOrds = new BitArray(sortedSetDocValues.getValueCount(), bigArrays);
+            boolean success = false;
+            try {
+                this.aggregators = aggregators;
+                this.withOrdinals = withOrdinals;
+                this.leafReaderContext = leafReaderContext;
+                final SortedSetDocValues sortedSetDocValues = withOrdinals.ordinalsValues(leafReaderContext);
+                this.currentReader = new BlockOrdinalsReader(sortedSetDocValues);
+                this.visitedOrds = new BitArray(sortedSetDocValues.getValueCount(), bigArrays);
+                success = true;
+            } finally {
+                if (success == false) {
+                    close();
+                }
+            }
         }
 
         void addInput(IntVector docs, Page page) {
             try {
+                GroupingAggregatorFunction.AddInput[] prepared = new GroupingAggregatorFunction.AddInput[aggregators.size()];
+                for (int i = 0; i < prepared.length; i++) {
+                    prepared[i] = aggregators.get(i).prepareProcessPage(this, page);
+                }
+
                 if (BlockOrdinalsReader.canReuse(currentReader, docs.getInt(0)) == false) {
                     currentReader = new BlockOrdinalsReader(withOrdinals.ordinalsValues(leafReaderContext));
                 }
@@ -336,7 +351,7 @@ public class OrdinalsGroupingOperator implements Operator {
                     }
                 }
                 for (GroupingAggregator aggregator : aggregators) {
-                    aggregator.prepareProcessPage(page).add(0, ordinals);
+                    aggregator.prepareProcessPage(this, page).add(0, ordinals);
                 }
             } catch (IOException e) {
                 throw new UncheckedIOException(e);
@@ -347,6 +362,13 @@ public class OrdinalsGroupingOperator implements Operator {
             return new AggregatedResultIterator(aggregators, visitedOrds, withOrdinals.ordinalsValues(leafReaderContext));
         }
 
+        @Override
+        public BitArray seenGroupIds(BigArrays bigArrays) {
+            BitArray snapshot = new BitArray(0, bigArrays); // fresh bitset, so the caller never shares visitedOrds itself
+            snapshot.or(visitedOrds); // or in the bits currently set in visitedOrds
+            return snapshot;
+        }
+
         @Override
         public void close() {
             Releasables.close(visitedOrds, () -> Releasables.close(aggregators));

+ 192 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/ArrayStateTests.java

@@ -0,0 +1,192 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.common.Randomness;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.data.BlockTestUtils;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+
+/**
+ * Tests that {@link IntArrayState}, {@link LongArrayState} and {@link DoubleArrayState} report
+ * {@code hasValue} and stored values correctly with group id tracking off, on, or enabled midway.
+ */
+public class ArrayStateTests extends ESTestCase {
+    @ParametersFactory
+    public static List<Object[]> params() {
+        List<Object[]> params = new ArrayList<>();
+        for (boolean inOrder : new boolean[] { true, false }) {
+            params.add(new Object[] { ElementType.INT, 1000, inOrder });
+            params.add(new Object[] { ElementType.LONG, 1000, inOrder });
+            params.add(new Object[] { ElementType.DOUBLE, 1000, inOrder });
+        }
+        return params;
+    }
+
+    private final ElementType elementType;
+    private final int valueCount;
+    private final boolean inOrder;
+
+    public ArrayStateTests(ElementType elementType, int valueCount, boolean inOrder) {
+        this.elementType = elementType;
+        this.valueCount = valueCount;
+        this.inOrder = inOrder;
+    }
+
+    /** With tracking off, every slot that was set reports a value. */
+    public void testSetNoTracking() {
+        List<Object> values = randomList(valueCount, valueCount, this::randomValue);
+
+        AbstractArrayState state = newState();
+        setAll(state, values, 0);
+        for (int i = 0; i < values.size(); i++) {
+            assertTrue(state.hasValue(i));
+            assertThat(get(state, i), equalTo(values.get(i)));
+        }
+    }
+
+    /** With tracking on from the start, only slots that were actually set report a value. */
+    public void testSetWithTracking() {
+        List<Object> values = randomList(valueCount, valueCount, this::nullableRandomValue);
+
+        AbstractArrayState state = newState();
+        state.enableGroupIdTracking(new SeenGroupIds.Empty());
+        setAll(state, values, 0);
+        for (int i = 0; i < values.size(); i++) {
+            if (values.get(i) == null) {
+                assertFalse(state.hasValue(i));
+            } else {
+                assertTrue(state.hasValue(i));
+                assertThat(get(state, i), equalTo(values.get(i)));
+            }
+        }
+    }
+
+    /** Fills one range with tracking off, enables tracking for it, then fills a second range that has gaps. */
+    public void testSetWithoutTrackingThenSetWithTracking() {
+        List<Object> withoutNulls = randomList(valueCount, valueCount, this::randomValue);
+        List<Object> withNulls = randomList(valueCount, valueCount, this::nullableRandomValue);
+
+        AbstractArrayState state = newState();
+        setAll(state, withoutNulls, 0);
+        state.enableGroupIdTracking(new SeenGroupIds.Range(0, withoutNulls.size()));
+        setAll(state, withNulls, withoutNulls.size());
+
+        for (int i = 0; i < withoutNulls.size(); i++) {
+            assertTrue(state.hasValue(i));
+            assertThat(get(state, i), equalTo(withoutNulls.get(i)));
+        }
+        for (int i = 0; i < withNulls.size(); i++) {
+            if (withNulls.get(i) == null) {
+                assertFalse(state.hasValue(i + withoutNulls.size()));
+            } else {
+                assertTrue(state.hasValue(i + withoutNulls.size()));
+                assertThat(get(state, i + withoutNulls.size()), equalTo(withNulls.get(i)));
+            }
+        }
+    }
+
+    /** A gap in the second pass must not hide the value the first, untracked pass stored. */
+    public void testSetNotNullableThenOverwriteNullable() {
+        List<Object> first = randomList(valueCount, valueCount, this::randomValue);
+        List<Object> second = randomList(valueCount, valueCount, this::nullableRandomValue);
+
+        AbstractArrayState state = newState();
+        setAll(state, first, 0);
+        state.enableGroupIdTracking(new SeenGroupIds.Range(0, valueCount));
+        setAll(state, second, 0);
+
+        for (int i = 0; i < valueCount; i++) {
+            assertTrue(state.hasValue(i));
+            Object expected = second.get(i) == null ? first.get(i) : second.get(i);
+            assertThat(get(state, i), equalTo(expected));
+        }
+    }
+
+    /** With tracking on from the start, a slot reports a value only if at least one pass set it. */
+    public void testSetNullableThenOverwriteNullable() {
+        List<Object> first = randomList(valueCount, valueCount, this::nullableRandomValue);
+        List<Object> second = randomList(valueCount, valueCount, this::nullableRandomValue);
+
+        AbstractArrayState state = newState();
+        state.enableGroupIdTracking(new SeenGroupIds.Empty());
+        setAll(state, first, 0);
+        setAll(state, second, 0);
+
+        for (int i = 0; i < valueCount; i++) {
+            Object expected = second.get(i) == null ? first.get(i) : second.get(i);
+            if (expected == null) {
+                assertFalse(state.hasValue(i));
+            } else {
+                assertTrue(state.hasValue(i));
+                assertThat(get(state, i), equalTo(expected));
+            }
+        }
+    }
+
+    private record ValueAndIndex(int index, Object value) {}
+
+    /** Sets each non-null value at {@code index + offset}; shuffles the write order when {@code inOrder} is false. */
+    private void setAll(AbstractArrayState state, List<Object> values, int offset) {
+        List<ValueAndIndex> order = new ArrayList<>(values.size());
+        for (int i = 0; i < values.size(); i++) {
+            order.add(new ValueAndIndex(i, values.get(i)));
+        }
+        if (inOrder == false) {
+            Randomness.shuffle(order);
+        }
+        for (ValueAndIndex v : order) {
+            if (v.value != null) {
+                set(state, v.index + offset, v.value);
+            }
+        }
+    }
+
+    private AbstractArrayState newState() {
+        return switch (elementType) {
+            case INT -> new IntArrayState(BigArrays.NON_RECYCLING_INSTANCE, 1);
+            case LONG -> new LongArrayState(BigArrays.NON_RECYCLING_INSTANCE, 1);
+            case DOUBLE -> new DoubleArrayState(BigArrays.NON_RECYCLING_INSTANCE, 1);
+            default -> throw new IllegalArgumentException();
+        };
+    }
+
+    private void set(AbstractArrayState state, int groupId, Object value) {
+        switch (elementType) {
+            case INT -> ((IntArrayState) state).set(groupId, (Integer) value);
+            case LONG -> ((LongArrayState) state).set(groupId, (Long) value);
+            case DOUBLE -> ((DoubleArrayState) state).set(groupId, (Double) value);
+            default -> throw new IllegalArgumentException();
+        }
+    }
+
+    private Object get(AbstractArrayState state, int index) {
+        return switch (elementType) {
+            case INT -> ((IntArrayState) state).get(index);
+            case LONG -> ((LongArrayState) state).get(index);
+            case DOUBLE -> ((DoubleArrayState) state).get(index);
+            default -> throw new IllegalArgumentException();
+        };
+    }
+
+    private Object randomValue() {
+        return BlockTestUtils.randomValue(elementType);
+    }
+
+    private Object nullableRandomValue() {
+        return randomBoolean() ? null : randomValue();
+    }
+}

+ 34 - 5
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java

@@ -10,6 +10,7 @@ package org.elasticsearch.compute.aggregation;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BooleanBlock;
 import org.elasticsearch.compute.data.BytesRefBlock;
@@ -29,6 +30,7 @@ import org.elasticsearch.compute.operator.NullInsertingSourceOperator;
 import org.elasticsearch.compute.operator.Operator;
 import org.elasticsearch.compute.operator.PositionMergingSourceOperator;
 import org.elasticsearch.compute.operator.SourceOperator;
+import org.elasticsearch.core.Releasables;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -408,17 +410,42 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
             public GroupingAggregatorFunction groupingAggregator() {
                 return new GroupingAggregatorFunction() {
                     GroupingAggregatorFunction delegate = supplier.groupingAggregator();
+                    BitArray seenGroupIds = new BitArray(0, nonBreakingBigArrays());
 
                     @Override
-                    public AddInput prepareProcessPage(Page page) {
+                    public AddInput prepareProcessPage(SeenGroupIds ignoredSeenGroupIds, Page page) {
                         return new AddInput() {
-                            AddInput delegateAddInput = delegate.prepareProcessPage(page);
+                            AddInput delegateAddInput = delegate.prepareProcessPage(bigArrays -> {
+                                BitArray seen = new BitArray(0, bigArrays);
+                                seen.or(seenGroupIds);
+                                return seen;
+                            }, page);
 
                             @Override
                             public void add(int positionOffset, LongBlock groupIds) {
                                 for (int offset = 0; offset < groupIds.getPositionCount(); offset += emitChunkSize) {
                                     LongBlock.Builder builder = LongBlock.newBlockBuilder(emitChunkSize);
-                                    builder.copyFrom(groupIds, offset, Math.min(groupIds.getPositionCount(), offset + emitChunkSize));
+                                    // Rebuild the chunk position by position so every forwarded group id is recorded.
+                                    int endP = Math.min(groupIds.getPositionCount(), offset + emitChunkSize);
+                                    for (int p = offset; p < endP; p++) {
+                                        int count = groupIds.getValueCount(p);
+                                        if (count == 0) {
+                                            builder.appendNull();
+                                            continue;
+                                        }
+                                        if (count > 1) {
+                                            builder.beginPositionEntry(); // keep all values of p in one position
+                                        }
+                                        int start = groupIds.getFirstValueIndex(p);
+                                        for (int i = start; i < start + count; i++) {
+                                            long group = groupIds.getLong(i);
+                                            seenGroupIds.set(group);
+                                            builder.appendLong(group);
+                                        }
+                                        if (count > 1) {
+                                            builder.endPositionEntry();
+                                        }
+                                    }
                                     delegateAddInput.add(positionOffset + offset, builder.build());
                                 }
                             }
@@ -429,7 +456,9 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
                                 for (int offset = 0; offset < groupIds.getPositionCount(); offset += emitChunkSize) {
                                     int count = 0;
                                     for (int i = offset; i < Math.min(groupIds.getPositionCount(), offset + emitChunkSize); i++) {
-                                        chunk[count++] = groupIds.getLong(i);
+                                        long group = groupIds.getLong(i);
+                                        seenGroupIds.set(group);
+                                        chunk[count++] = group;
                                     }
                                     delegateAddInput.add(positionOffset + offset, new LongArrayVector(chunk, count));
                                 }
@@ -471,7 +500,7 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
 
                     @Override
                     public void close() {
-                        delegate.close();
+                        Releasables.close(delegate::close, seenGroupIds); // also releases the BitArray of seen group ids
                     }
 
                     @Override

+ 0 - 2
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java

@@ -179,7 +179,6 @@ public class EsqlActionIT extends AbstractEsqlIntegTestCase {
         assertEquals(expectedValues, actualValues);
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch-internal/issues/1306")
     public void testFromGroupingByNumericFieldWithNulls() {
         for (int i = 0; i < 5; i++) {
             client().prepareBulk()
@@ -249,7 +248,6 @@ public class EsqlActionIT extends AbstractEsqlIntegTestCase {
         assertThat(actualGroups, equalTo(expectedGroups));
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch-internal/issues/1306")
     public void testFromStatsGroupingByKeywordWithNulls() {
         for (int i = 0; i < 5; i++) {
             client().prepareBulk()